mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-07 04:45:44 +00:00
Merge branch 'main' into bugfix/model-type
This commit is contained in:
commit
41948027df
205 changed files with 22222 additions and 10702 deletions
|
@ -5,10 +5,10 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
from enum import StrEnum
|
||||
from typing import Annotated, Literal, Protocol, runtime_checkable
|
||||
from typing import Annotated, ClassVar, Literal, Protocol, runtime_checkable
|
||||
|
||||
from fastapi import File, Form, Response, UploadFile
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.apis.common.responses import Order
|
||||
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
||||
|
@ -49,6 +49,23 @@ class OpenAIFileObject(BaseModel):
|
|||
purpose: OpenAIFilePurpose
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ExpiresAfter(BaseModel):
|
||||
"""
|
||||
Control expiration of uploaded files.
|
||||
|
||||
Params:
|
||||
- anchor, must be "created_at"
|
||||
- seconds, must be int between 3600 and 2592000 (1 hour to 30 days)
|
||||
"""
|
||||
|
||||
MIN: ClassVar[int] = 3600 # 1 hour
|
||||
MAX: ClassVar[int] = 2592000 # 30 days
|
||||
|
||||
anchor: Literal["created_at"]
|
||||
seconds: int = Field(..., ge=3600, le=2592000)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ListOpenAIFileResponse(BaseModel):
|
||||
"""
|
||||
|
@ -92,6 +109,9 @@ class Files(Protocol):
|
|||
self,
|
||||
file: Annotated[UploadFile, File()],
|
||||
purpose: Annotated[OpenAIFilePurpose, Form()],
|
||||
expires_after_anchor: Annotated[str | None, Form(alias="expires_after[anchor]")] = None,
|
||||
expires_after_seconds: Annotated[int | None, Form(alias="expires_after[seconds]")] = None,
|
||||
# TODO: expires_after is producing strange openapi spec, params are showing up as a required w/ oneOf being null
|
||||
) -> OpenAIFileObject:
|
||||
"""
|
||||
Upload a file that can be used across various endpoints.
|
||||
|
@ -99,6 +119,7 @@ class Files(Protocol):
|
|||
The file upload should be a multipart form request with:
|
||||
- file: The File object (not file name) to be uploaded.
|
||||
- purpose: The intended purpose of the uploaded file.
|
||||
- expires_after: Optional form values describing expiration for the file. Expected expires_after[anchor] = "created_at", expires_after[seconds] = <int>. Seconds must be between 3600 and 2592000 (1 hour to 30 days).
|
||||
|
||||
:param file: The uploaded file object containing content and metadata (filename, content_type, etc.).
|
||||
:param purpose: The intended purpose of the uploaded file (e.g., "assistants", "fine-tune").
|
||||
|
|
|
@ -80,7 +80,7 @@ def get_provider_dependencies(
|
|||
normal_deps = []
|
||||
special_deps = []
|
||||
for package in deps:
|
||||
if "--no-deps" in package or "--index-url" in package:
|
||||
if any(f in package for f in ["--no-deps", "--index-url", "--extra-index-url"]):
|
||||
special_deps.append(package)
|
||||
else:
|
||||
normal_deps.append(package)
|
||||
|
|
|
@ -284,7 +284,15 @@ async def instantiate_providers(
|
|||
if provider.provider_id is None:
|
||||
continue
|
||||
|
||||
deps = {a: impls[a] for a in provider.spec.api_dependencies}
|
||||
try:
|
||||
deps = {a: impls[a] for a in provider.spec.api_dependencies}
|
||||
except KeyError as e:
|
||||
missing_api = e.args[0]
|
||||
raise RuntimeError(
|
||||
f"Failed to resolve '{provider.spec.api.value}' provider '{provider.provider_id}' of type '{provider.spec.provider_type}': "
|
||||
f"required dependency '{missing_api.value}' is not available. "
|
||||
f"Please add a '{missing_api.value}' provider to your configuration or check if the provider is properly configured."
|
||||
) from e
|
||||
for a in provider.spec.optional_api_dependencies:
|
||||
if a in impls:
|
||||
deps[a] = impls[a]
|
||||
|
|
|
@ -755,7 +755,7 @@ class InferenceRouter(Inference):
|
|||
choices_data[idx] = {
|
||||
"content_parts": [],
|
||||
"tool_calls_builder": {},
|
||||
"finish_reason": None,
|
||||
"finish_reason": "stop",
|
||||
"logprobs_content_parts": [],
|
||||
}
|
||||
current_choice_data = choices_data[idx]
|
||||
|
|
|
@ -105,12 +105,12 @@ async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]):
|
|||
|
||||
method = getattr(impls[api], register_method)
|
||||
for obj in objects:
|
||||
logger.debug(f"registering {rsrc.capitalize()} {obj} for provider {obj.provider_id}")
|
||||
|
||||
# Do not register models on disabled providers
|
||||
if hasattr(obj, "provider_id") and (not obj.provider_id or obj.provider_id == "__disabled__"):
|
||||
logger.debug(f"Skipping {rsrc.capitalize()} registration for disabled provider.")
|
||||
continue
|
||||
if hasattr(obj, "provider_id"):
|
||||
# Do not register models on disabled providers
|
||||
if not obj.provider_id or obj.provider_id == "__disabled__":
|
||||
logger.debug(f"Skipping {rsrc.capitalize()} registration for disabled provider.")
|
||||
continue
|
||||
logger.debug(f"registering {rsrc.capitalize()} {obj} for provider {obj.provider_id}")
|
||||
|
||||
# we want to maintain the type information in arguments to method.
|
||||
# instead of method(**obj.model_dump()), which may convert a typed attr to a dict,
|
||||
|
@ -225,7 +225,10 @@ def replace_env_vars(config: Any, path: str = "") -> Any:
|
|||
|
||||
try:
|
||||
result = re.sub(pattern, get_env_var, config)
|
||||
return _convert_string_to_proper_type(result)
|
||||
# Only apply type conversion if substitution actually happened
|
||||
if result != config:
|
||||
return _convert_string_to_proper_type(result)
|
||||
return result
|
||||
except EnvVarError as e:
|
||||
raise EnvVarError(e.var_name, e.path) from None
|
||||
|
||||
|
|
|
@ -34,7 +34,7 @@ distribution_spec:
|
|||
telemetry:
|
||||
- provider_type: inline::meta-reference
|
||||
post_training:
|
||||
- provider_type: inline::huggingface-cpu
|
||||
- provider_type: inline::torchtune-cpu
|
||||
eval:
|
||||
- provider_type: inline::meta-reference
|
||||
datasetio:
|
||||
|
|
|
@ -156,13 +156,10 @@ providers:
|
|||
sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/trace_store.db
|
||||
otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=}
|
||||
post_training:
|
||||
- provider_id: huggingface-cpu
|
||||
provider_type: inline::huggingface-cpu
|
||||
- provider_id: torchtune-cpu
|
||||
provider_type: inline::torchtune-cpu
|
||||
config:
|
||||
checkpoint_format: huggingface
|
||||
distributed_backend: null
|
||||
device: cpu
|
||||
dpo_output_dir: ~/.llama/distributions/ci-tests/dpo_output
|
||||
checkpoint_format: meta
|
||||
eval:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
---
|
||||
orphan: true
|
||||
---
|
||||
# Meta Reference Distribution
|
||||
# Meta Reference GPU Distribution
|
||||
|
||||
```{toctree}
|
||||
:maxdepth: 2
|
||||
|
@ -29,7 +29,7 @@ The following environment variables can be configured:
|
|||
|
||||
## Prerequisite: Downloading Models
|
||||
|
||||
Please use `llama model list --downloaded` to check that you have llama model checkpoints downloaded in `~/.llama` before proceeding. See [installation guide](https://llama-stack.readthedocs.io/en/latest/references/llama_cli_reference/download_models.html) here to download the models. Run `llama model list` to see the available models to download, and `llama model download` to download the checkpoints.
|
||||
Please use `llama model list --downloaded` to check that you have llama model checkpoints downloaded in `~/.llama` before proceeding. See [installation guide](../../references/llama_cli_reference/download_models.md) here to download the models. Run `llama model list` to see the available models to download, and `llama model download` to download the checkpoints.
|
||||
|
||||
```
|
||||
$ llama model list --downloaded
|
||||
|
|
|
@ -134,6 +134,11 @@ models:
|
|||
provider_id: nvidia
|
||||
provider_model_id: meta/llama-3.3-70b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: nvidia/vila
|
||||
provider_id: nvidia
|
||||
provider_model_id: nvidia/vila
|
||||
model_type: llm
|
||||
- metadata:
|
||||
embedding_dimension: 2048
|
||||
context_length: 8192
|
||||
|
|
|
@ -43,7 +43,7 @@ def get_inference_providers() -> tuple[list[Provider], dict[str, list[ProviderMo
|
|||
"openai",
|
||||
[
|
||||
ProviderModelEntry(
|
||||
provider_model_id="openai/gpt-4o",
|
||||
provider_model_id="gpt-4o",
|
||||
model_type=ModelType.llm,
|
||||
)
|
||||
],
|
||||
|
@ -53,7 +53,7 @@ def get_inference_providers() -> tuple[list[Provider], dict[str, list[ProviderMo
|
|||
"anthropic",
|
||||
[
|
||||
ProviderModelEntry(
|
||||
provider_model_id="anthropic/claude-3-5-sonnet-latest",
|
||||
provider_model_id="claude-3-5-sonnet-latest",
|
||||
model_type=ModelType.llm,
|
||||
)
|
||||
],
|
||||
|
@ -206,13 +206,6 @@ def get_distribution_template() -> DistributionTemplate:
|
|||
uri="huggingface://datasets/llamastack/math_500?split=test",
|
||||
),
|
||||
),
|
||||
DatasetInput(
|
||||
dataset_id="bfcl",
|
||||
purpose=DatasetPurpose.eval_messages_answer,
|
||||
source=URIDataSource(
|
||||
uri="huggingface://datasets/llamastack/bfcl_v3?split=train",
|
||||
),
|
||||
),
|
||||
DatasetInput(
|
||||
dataset_id="ifeval",
|
||||
purpose=DatasetPurpose.eval_messages_answer,
|
||||
|
@ -250,11 +243,6 @@ def get_distribution_template() -> DistributionTemplate:
|
|||
dataset_id="math_500",
|
||||
scoring_functions=["basic::regex_parser_math_response"],
|
||||
),
|
||||
BenchmarkInput(
|
||||
benchmark_id="meta-reference-bfcl",
|
||||
dataset_id="bfcl",
|
||||
scoring_functions=["basic::bfcl"],
|
||||
),
|
||||
BenchmarkInput(
|
||||
benchmark_id="meta-reference-ifeval",
|
||||
dataset_id="ifeval",
|
||||
|
|
|
@ -136,14 +136,14 @@ inference_store:
|
|||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/open-benchmark}/inference_store.db
|
||||
models:
|
||||
- metadata: {}
|
||||
model_id: openai/gpt-4o
|
||||
model_id: gpt-4o
|
||||
provider_id: openai
|
||||
provider_model_id: openai/gpt-4o
|
||||
provider_model_id: gpt-4o
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: anthropic/claude-3-5-sonnet-latest
|
||||
model_id: claude-3-5-sonnet-latest
|
||||
provider_id: anthropic
|
||||
provider_model_id: anthropic/claude-3-5-sonnet-latest
|
||||
provider_model_id: claude-3-5-sonnet-latest
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: gemini/gemini-1.5-flash
|
||||
|
@ -188,12 +188,6 @@ datasets:
|
|||
uri: huggingface://datasets/llamastack/math_500?split=test
|
||||
metadata: {}
|
||||
dataset_id: math_500
|
||||
- purpose: eval/messages-answer
|
||||
source:
|
||||
type: uri
|
||||
uri: huggingface://datasets/llamastack/bfcl_v3?split=train
|
||||
metadata: {}
|
||||
dataset_id: bfcl
|
||||
- purpose: eval/messages-answer
|
||||
source:
|
||||
type: uri
|
||||
|
@ -228,11 +222,6 @@ benchmarks:
|
|||
- basic::regex_parser_math_response
|
||||
metadata: {}
|
||||
benchmark_id: meta-reference-math-500
|
||||
- dataset_id: bfcl
|
||||
scoring_functions:
|
||||
- basic::bfcl
|
||||
metadata: {}
|
||||
benchmark_id: meta-reference-bfcl
|
||||
- dataset_id: ifeval
|
||||
scoring_functions:
|
||||
- basic::ifeval
|
||||
|
|
|
@ -35,7 +35,7 @@ distribution_spec:
|
|||
telemetry:
|
||||
- provider_type: inline::meta-reference
|
||||
post_training:
|
||||
- provider_type: inline::torchtune-gpu
|
||||
- provider_type: inline::huggingface-gpu
|
||||
eval:
|
||||
- provider_type: inline::meta-reference
|
||||
datasetio:
|
||||
|
|
|
@ -156,10 +156,13 @@ providers:
|
|||
sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/trace_store.db
|
||||
otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=}
|
||||
post_training:
|
||||
- provider_id: torchtune-gpu
|
||||
provider_type: inline::torchtune-gpu
|
||||
- provider_id: huggingface-gpu
|
||||
provider_type: inline::huggingface-gpu
|
||||
config:
|
||||
checkpoint_format: meta
|
||||
checkpoint_format: huggingface
|
||||
distributed_backend: null
|
||||
device: cpu
|
||||
dpo_output_dir: ~/.llama/distributions/starter-gpu/dpo_output
|
||||
eval:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
|
|
|
@ -17,6 +17,6 @@ def get_distribution_template() -> DistributionTemplate:
|
|||
template.description = "Quick start template for running Llama Stack with several popular providers. This distribution is intended for GPU-enabled environments."
|
||||
|
||||
template.providers["post_training"] = [
|
||||
BuildProvider(provider_type="inline::torchtune-gpu"),
|
||||
BuildProvider(provider_type="inline::huggingface-gpu"),
|
||||
]
|
||||
return template
|
||||
|
|
|
@ -35,7 +35,7 @@ distribution_spec:
|
|||
telemetry:
|
||||
- provider_type: inline::meta-reference
|
||||
post_training:
|
||||
- provider_type: inline::huggingface-cpu
|
||||
- provider_type: inline::torchtune-cpu
|
||||
eval:
|
||||
- provider_type: inline::meta-reference
|
||||
datasetio:
|
||||
|
|
|
@ -156,13 +156,10 @@ providers:
|
|||
sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/trace_store.db
|
||||
otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=}
|
||||
post_training:
|
||||
- provider_id: huggingface-cpu
|
||||
provider_type: inline::huggingface-cpu
|
||||
- provider_id: torchtune-cpu
|
||||
provider_type: inline::torchtune-cpu
|
||||
config:
|
||||
checkpoint_format: huggingface
|
||||
distributed_backend: null
|
||||
device: cpu
|
||||
dpo_output_dir: ~/.llama/distributions/starter/dpo_output
|
||||
checkpoint_format: meta
|
||||
eval:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
|
|
|
@ -120,7 +120,7 @@ def get_distribution_template() -> DistributionTemplate:
|
|||
],
|
||||
"agents": [BuildProvider(provider_type="inline::meta-reference")],
|
||||
"telemetry": [BuildProvider(provider_type="inline::meta-reference")],
|
||||
"post_training": [BuildProvider(provider_type="inline::huggingface-cpu")],
|
||||
"post_training": [BuildProvider(provider_type="inline::torchtune-cpu")],
|
||||
"eval": [BuildProvider(provider_type="inline::meta-reference")],
|
||||
"datasetio": [
|
||||
BuildProvider(provider_type="remote::huggingface"),
|
||||
|
|
|
@ -86,11 +86,16 @@ class LocalfsFilesImpl(Files):
|
|||
self,
|
||||
file: Annotated[UploadFile, File()],
|
||||
purpose: Annotated[OpenAIFilePurpose, Form()],
|
||||
expires_after_anchor: Annotated[str | None, Form(alias="expires_after[anchor]")] = None,
|
||||
expires_after_seconds: Annotated[int | None, Form(alias="expires_after[seconds]")] = None,
|
||||
) -> OpenAIFileObject:
|
||||
"""Upload a file that can be used across various endpoints."""
|
||||
if not self.sql_store:
|
||||
raise RuntimeError("Files provider not initialized")
|
||||
|
||||
if expires_after_anchor is not None or expires_after_seconds is not None:
|
||||
raise NotImplementedError("File expiration is not supported by this provider")
|
||||
|
||||
file_id = self._generate_file_id()
|
||||
file_path = self._get_file_path(file_id)
|
||||
|
||||
|
|
|
@ -22,7 +22,6 @@ from llama_stack.providers.utils.common.data_schema_validator import (
|
|||
)
|
||||
|
||||
from .config import BasicScoringConfig
|
||||
from .scoring_fn.bfcl_scoring_fn import BFCLScoringFn
|
||||
from .scoring_fn.docvqa_scoring_fn import DocVQAScoringFn
|
||||
from .scoring_fn.equality_scoring_fn import EqualityScoringFn
|
||||
from .scoring_fn.ifeval_scoring_fn import IfEvalScoringFn
|
||||
|
@ -37,7 +36,6 @@ FIXED_FNS = [
|
|||
SubsetOfScoringFn,
|
||||
RegexParserScoringFn,
|
||||
RegexParserMathResponseScoringFn,
|
||||
BFCLScoringFn,
|
||||
IfEvalScoringFn,
|
||||
DocVQAScoringFn,
|
||||
]
|
||||
|
|
|
@ -1,93 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.scoring import ScoringResultRow
|
||||
from llama_stack.apis.scoring_functions import ScoringFnParams
|
||||
from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
|
||||
|
||||
from ..utils.bfcl.ast_parser import decode_ast
|
||||
from ..utils.bfcl.checker import ast_checker, is_empty_output
|
||||
from .fn_defs.bfcl import bfcl
|
||||
|
||||
|
||||
def postprocess(x: dict[str, Any], test_category: str) -> dict[str, Any]:
|
||||
contain_func_call = False
|
||||
error = None
|
||||
error_type = None
|
||||
checker_result = {}
|
||||
try:
|
||||
prediction = decode_ast(x["generated_answer"], x["language"]) or ""
|
||||
contain_func_call = True
|
||||
# if not is_function_calling_format_output(prediction):
|
||||
if is_empty_output(prediction):
|
||||
contain_func_call = False
|
||||
error = "Did not output in the specified format. Note: the model_result is wrapped in a string to ensure json serializability."
|
||||
error_type = "ast_decoder:decoder_wrong_output_format"
|
||||
else:
|
||||
checker_result = ast_checker(
|
||||
json.loads(x["function"]),
|
||||
prediction,
|
||||
json.loads(x["ground_truth"]),
|
||||
x["language"],
|
||||
test_category=test_category,
|
||||
model_name="",
|
||||
)
|
||||
except Exception as e:
|
||||
prediction = ""
|
||||
error = f"Invalid syntax. Failed to decode AST. {str(e)}"
|
||||
error_type = "ast_decoder:decoder_failed"
|
||||
return {
|
||||
"prediction": prediction,
|
||||
"contain_func_call": contain_func_call,
|
||||
"valid": checker_result.get("valid", False),
|
||||
"error": error or checker_result.get("error", ""),
|
||||
"error_type": error_type or checker_result.get("error_type", ""),
|
||||
}
|
||||
|
||||
|
||||
def gen_valid(x: dict[str, Any]) -> dict[str, float]:
|
||||
return {"valid": x["valid"]}
|
||||
|
||||
|
||||
def gen_relevance_acc(x: dict[str, Any]) -> dict[str, float]:
|
||||
# This function serves for both relevance and irrelevance tests, which share the exact opposite logic.
|
||||
# If `test_category` is "irrelevance", the model is expected to output no function call.
|
||||
# No function call means either the AST decoding fails (a error message is generated) or the decoded AST does not contain any function call (such as a empty list, `[]`).
|
||||
# If `test_category` is "relevance", the model is expected to output to a function call, and empty list doesn't count as a function call.
|
||||
acc = not x["contain_func_call"] if "irrelevance" in x["id"] else x["contain_func_call"]
|
||||
return {"valid": float(acc)}
|
||||
|
||||
|
||||
class BFCLScoringFn(RegisteredBaseScoringFn):
|
||||
"""
|
||||
A scoring_fn for BFCL
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
self.supported_fn_defs_registry = {
|
||||
bfcl.identifier: bfcl,
|
||||
}
|
||||
|
||||
async def score_row(
|
||||
self,
|
||||
input_row: dict[str, Any],
|
||||
scoring_fn_identifier: str | None = "bfcl",
|
||||
scoring_params: ScoringFnParams | None = None,
|
||||
) -> ScoringResultRow:
|
||||
test_category = re.sub(r"_[0-9_-]+$", "", input_row["id"])
|
||||
score_result = postprocess(input_row, test_category)
|
||||
if test_category in {"irrelevance", "live_relevance", "live_irrelevance"}:
|
||||
score = gen_relevance_acc(score_result)["valid"]
|
||||
else:
|
||||
score = gen_valid(score_result)["valid"]
|
||||
return {
|
||||
"score": float(score),
|
||||
}
|
|
@ -1,21 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.common.type_system import NumberType
|
||||
from llama_stack.apis.scoring_functions import (
|
||||
AggregationFunctionType,
|
||||
BasicScoringFnParams,
|
||||
ScoringFn,
|
||||
)
|
||||
|
||||
bfcl = ScoringFn(
|
||||
identifier="basic::bfcl",
|
||||
description="BFCL complex scoring",
|
||||
return_type=NumberType(),
|
||||
provider_id="basic",
|
||||
provider_resource_id="bfcl",
|
||||
params=BasicScoringFnParams(aggregation_functions=[AggregationFunctionType.accuracy]),
|
||||
)
|
|
@ -1,5 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
|
@ -1,296 +0,0 @@
|
|||
# ruff: noqa
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
import ast
|
||||
|
||||
from .tree_sitter import get_parser
|
||||
|
||||
|
||||
def parse_java_function_call(source_code):
|
||||
if not source_code.endswith(";"):
|
||||
source_code += ";" # Necessary for the parser not to register an error
|
||||
parser = get_parser("java")
|
||||
tree = parser.parse(bytes(source_code, "utf8"))
|
||||
root_node = tree.root_node
|
||||
|
||||
if root_node.has_error:
|
||||
raise Exception("Error parsing java the source code.")
|
||||
|
||||
def get_text(node):
|
||||
"""Returns the text represented by the node."""
|
||||
return source_code[node.start_byte : node.end_byte]
|
||||
|
||||
def traverse_node(node, nested=False):
|
||||
if node.type == "string_literal":
|
||||
if nested:
|
||||
return get_text(node)
|
||||
# Strip surrounding quotes from string literals
|
||||
return get_text(node)[1:-1]
|
||||
elif node.type == "character_literal":
|
||||
if nested:
|
||||
return get_text(node)
|
||||
# Strip surrounding single quotes from character literals
|
||||
return get_text(node)[1:-1]
|
||||
"""Traverse the node to collect texts for complex structures."""
|
||||
if node.type in [
|
||||
"identifier",
|
||||
"class_literal",
|
||||
"type_identifier",
|
||||
"method_invocation",
|
||||
]:
|
||||
return get_text(node)
|
||||
elif node.type == "array_creation_expression":
|
||||
# Handle array creation expression specifically
|
||||
type_node = node.child_by_field_name("type")
|
||||
value_node = node.child_by_field_name("value")
|
||||
type_text = traverse_node(type_node, True)
|
||||
value_text = traverse_node(value_node, True)
|
||||
return f"new {type_text}[]{value_text}"
|
||||
elif node.type == "object_creation_expression":
|
||||
# Handle object creation expression specifically
|
||||
type_node = node.child_by_field_name("type")
|
||||
arguments_node = node.child_by_field_name("arguments")
|
||||
type_text = traverse_node(type_node, True)
|
||||
if arguments_node:
|
||||
# Process each argument carefully, avoiding unnecessary punctuation
|
||||
argument_texts = []
|
||||
for child in arguments_node.children:
|
||||
if child.type not in [
|
||||
",",
|
||||
"(",
|
||||
")",
|
||||
]: # Exclude commas and parentheses
|
||||
argument_text = traverse_node(child, True)
|
||||
argument_texts.append(argument_text)
|
||||
arguments_text = ", ".join(argument_texts)
|
||||
return f"new {type_text}({arguments_text})"
|
||||
else:
|
||||
return f"new {type_text}()"
|
||||
elif node.type == "set":
|
||||
# Handling sets specifically
|
||||
items = [traverse_node(n, True) for n in node.children if n.type not in [",", "set"]]
|
||||
return "{" + ", ".join(items) + "}"
|
||||
|
||||
elif node.child_count > 0:
|
||||
return "".join(traverse_node(child, True) for child in node.children)
|
||||
else:
|
||||
return get_text(node)
|
||||
|
||||
def extract_arguments(args_node):
|
||||
arguments = {}
|
||||
for child in args_node.children:
|
||||
if child.type == "assignment_expression":
|
||||
# For named parameters
|
||||
name_node, value_node = child.children[0], child.children[2]
|
||||
name = get_text(name_node)
|
||||
value = traverse_node(value_node)
|
||||
if name in arguments:
|
||||
if not isinstance(arguments[name], list):
|
||||
arguments[name] = [arguments[name]]
|
||||
arguments[name].append(value)
|
||||
else:
|
||||
arguments[name] = value
|
||||
# arguments.append({'name': name, 'value': value})
|
||||
elif child.type in ["identifier", "class_literal", "set"]:
|
||||
# For unnamed parameters and handling sets
|
||||
value = traverse_node(child)
|
||||
if None in arguments:
|
||||
if not isinstance(arguments[None], list):
|
||||
arguments[None] = [arguments[None]]
|
||||
arguments[None].append(value)
|
||||
else:
|
||||
arguments[None] = value
|
||||
return arguments
|
||||
|
||||
def traverse(node):
|
||||
if node.type == "method_invocation":
|
||||
# Extract the function name and its arguments
|
||||
method_name = get_text(node.child_by_field_name("name"))
|
||||
class_name_node = node.child_by_field_name("object")
|
||||
if class_name_node:
|
||||
class_name = get_text(class_name_node)
|
||||
function_name = f"{class_name}.{method_name}"
|
||||
else:
|
||||
function_name = method_name
|
||||
arguments_node = node.child_by_field_name("arguments")
|
||||
if arguments_node:
|
||||
arguments = extract_arguments(arguments_node)
|
||||
for key, value in arguments.items():
|
||||
if isinstance(value, list):
|
||||
raise Exception("Error: Multiple arguments with the same name are not supported.")
|
||||
return [{function_name: arguments}]
|
||||
|
||||
else:
|
||||
for child in node.children:
|
||||
result = traverse(child)
|
||||
if result:
|
||||
return result
|
||||
|
||||
result = traverse(root_node)
|
||||
return result if result else {}
|
||||
|
||||
|
||||
def parse_javascript_function_call(source_code):
|
||||
if not source_code.endswith(";"):
|
||||
source_code += ";" # Necessary for the parser not to register an error
|
||||
parser = get_parser("javascript")
|
||||
# Parse the source code
|
||||
tree = parser.parse(bytes(source_code, "utf8"))
|
||||
root_node = tree.root_node
|
||||
if root_node.has_error:
|
||||
raise Exception("Error js parsing the source code.")
|
||||
|
||||
# Function to recursively extract argument details
|
||||
def extract_arguments(node):
|
||||
args = {}
|
||||
for child in node.children:
|
||||
if child.type == "assignment_expression":
|
||||
# Extract left (name) and right (value) parts of the assignment
|
||||
name = child.children[0].text.decode("utf-8")
|
||||
value = child.children[2].text.decode("utf-8")
|
||||
if (value.startswith('"') and value.endswith('"')) or (value.startswith("'") and value.endswith("'")):
|
||||
value = value[1:-1] # Trim the quotation marks
|
||||
if name in args:
|
||||
if not isinstance(args[name], list):
|
||||
args[name] = [args[name]]
|
||||
args[name].append(value)
|
||||
else:
|
||||
args[name] = value
|
||||
|
||||
elif child.type == "identifier" or child.type == "true":
|
||||
# Handle non-named arguments and boolean values
|
||||
value = child.text.decode("utf-8")
|
||||
if None in args:
|
||||
if not isinstance(args[None], list):
|
||||
args[None] = [args[None]]
|
||||
args[None].append(value)
|
||||
else:
|
||||
args[None] = value
|
||||
return args
|
||||
|
||||
# Find the function call and extract its name and arguments
|
||||
if root_node.type == "program":
|
||||
for child in root_node.children:
|
||||
if child.type == "expression_statement":
|
||||
for sub_child in child.children:
|
||||
if sub_child.type == "call_expression":
|
||||
function_name = sub_child.children[0].text.decode("utf8")
|
||||
arguments_node = sub_child.children[1]
|
||||
parameters = extract_arguments(arguments_node)
|
||||
for key, value in parameters.items():
|
||||
if isinstance(value, list):
|
||||
raise Exception("Error: Multiple arguments with the same name are not supported.")
|
||||
result = [{function_name: parameters}]
|
||||
return result
|
||||
|
||||
|
||||
def ast_parse(input_str, language="Python"):
|
||||
if language == "Python":
|
||||
cleaned_input = input_str.strip("[]'")
|
||||
parsed = ast.parse(cleaned_input, mode="eval")
|
||||
extracted = []
|
||||
if isinstance(parsed.body, ast.Call):
|
||||
extracted.append(resolve_ast_call(parsed.body))
|
||||
else:
|
||||
for elem in parsed.body.elts:
|
||||
extracted.append(resolve_ast_call(elem))
|
||||
return extracted
|
||||
elif language == "Java":
|
||||
return parse_java_function_call(input_str[1:-1]) # Remove the [ and ] from the string
|
||||
elif language == "JavaScript":
|
||||
return parse_javascript_function_call(input_str[1:-1])
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported language: {language}")
|
||||
|
||||
|
||||
def resolve_ast_call(elem):
|
||||
# Handle nested attributes for deeply nested module paths
|
||||
func_parts = []
|
||||
func_part = elem.func
|
||||
while isinstance(func_part, ast.Attribute):
|
||||
func_parts.append(func_part.attr)
|
||||
func_part = func_part.value
|
||||
if isinstance(func_part, ast.Name):
|
||||
func_parts.append(func_part.id)
|
||||
func_name = ".".join(reversed(func_parts))
|
||||
args_dict = {}
|
||||
# Parse when args are simply passed as an unnamed dictionary arg
|
||||
for arg in elem.args:
|
||||
if isinstance(arg, ast.Dict):
|
||||
for key, value in zip(arg.keys, arg.values):
|
||||
if isinstance(key, ast.Constant):
|
||||
arg_name = key.value
|
||||
output = resolve_ast_by_type(value)
|
||||
args_dict[arg_name] = output
|
||||
for arg in elem.keywords:
|
||||
output = resolve_ast_by_type(arg.value)
|
||||
args_dict[arg.arg] = output
|
||||
return {func_name: args_dict}
|
||||
|
||||
|
||||
def resolve_ast_by_type(value):
|
||||
if isinstance(value, ast.Constant):
|
||||
if value.value is Ellipsis:
|
||||
output = "..."
|
||||
else:
|
||||
output = value.value
|
||||
elif isinstance(value, ast.UnaryOp):
|
||||
output = -value.operand.value
|
||||
elif isinstance(value, ast.List):
|
||||
output = [resolve_ast_by_type(v) for v in value.elts]
|
||||
elif isinstance(value, ast.Dict):
|
||||
output = {resolve_ast_by_type(k): resolve_ast_by_type(v) for k, v in zip(value.keys, value.values)}
|
||||
elif isinstance(value, ast.NameConstant): # Added this condition to handle boolean values
|
||||
output = value.value
|
||||
elif isinstance(value, ast.BinOp): # Added this condition to handle function calls as arguments
|
||||
output = eval(ast.unparse(value))
|
||||
elif isinstance(value, ast.Name):
|
||||
output = value.id
|
||||
elif isinstance(value, ast.Call):
|
||||
if len(value.keywords) == 0:
|
||||
output = ast.unparse(value)
|
||||
else:
|
||||
output = resolve_ast_call(value)
|
||||
elif isinstance(value, ast.Tuple):
|
||||
output = tuple(resolve_ast_by_type(v) for v in value.elts)
|
||||
elif isinstance(value, ast.Lambda):
|
||||
output = eval(ast.unparse(value.body[0].value))
|
||||
elif isinstance(value, ast.Ellipsis):
|
||||
output = "..."
|
||||
elif isinstance(value, ast.Subscript):
|
||||
try:
|
||||
output = ast.unparse(value.body[0].value)
|
||||
except:
|
||||
output = ast.unparse(value.value) + "[" + ast.unparse(value.slice) + "]"
|
||||
else:
|
||||
raise Exception(f"Unsupported AST type: {type(value)}")
|
||||
return output
|
||||
|
||||
|
||||
def decode_ast(result, language="Python"):
|
||||
func = result
|
||||
func = func.replace("\n", "") # remove new line characters
|
||||
if not func.startswith("["):
|
||||
func = "[" + func
|
||||
if not func.endswith("]"):
|
||||
func = func + "]"
|
||||
decoded_output = ast_parse(func, language)
|
||||
return decoded_output
|
||||
|
||||
|
||||
def decode_execute(result):
|
||||
func = result
|
||||
func = func.replace("\n", "") # remove new line characters
|
||||
if not func.startswith("["):
|
||||
func = "[" + func
|
||||
if not func.endswith("]"):
|
||||
func = func + "]"
|
||||
decode_output = ast_parse(func)
|
||||
execution_list = []
|
||||
for function_call in decode_output:
|
||||
for key, value in function_call.items():
|
||||
execution_list.append(f"{key}({','.join([f'{k}={repr(v)}' for k, v in value.items()])})")
|
||||
return execution_list
|
|
@ -1,989 +0,0 @@
|
|||
# ruff: noqa
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
import json
|
||||
import re
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
# Comment out for now until we actually use the rest checker in evals
|
||||
# import requests # Do not remove this import even though it seems to be unused. It's used in the executable_checker_rest function.
|
||||
|
||||
|
||||
class NoAPIKeyError(Exception):
|
||||
def __init__(self):
|
||||
self.message = "❗️Please fill in the API keys in the function_credential_config.json file. If you do not provide the API keys, the executable test category results will be inaccurate."
|
||||
super().__init__(self.message)
|
||||
|
||||
|
||||
REAL_TIME_MATCH_ALLOWED_DIFFERENCE = 0.2
|
||||
|
||||
|
||||
JAVA_TYPE_CONVERSION = {
|
||||
"byte": int,
|
||||
"short": int,
|
||||
"integer": int,
|
||||
"float": float,
|
||||
"double": float,
|
||||
"long": int,
|
||||
"boolean": bool,
|
||||
"char": str,
|
||||
"Array": list,
|
||||
"ArrayList": list,
|
||||
"Set": set,
|
||||
"HashMap": dict,
|
||||
"Hashtable": dict,
|
||||
"Queue": list, # this can be `queue.Queue` as well, for simplicity we check with list
|
||||
"Stack": list,
|
||||
"String": str,
|
||||
"any": str,
|
||||
}
|
||||
|
||||
JS_TYPE_CONVERSION = {
|
||||
"String": str,
|
||||
"integer": int,
|
||||
"float": float,
|
||||
"Bigint": int,
|
||||
"Boolean": bool,
|
||||
"dict": dict,
|
||||
"array": list,
|
||||
"any": str,
|
||||
}
|
||||
|
||||
# We switch to conditional import for the following two imports to avoid unnecessary installations.
|
||||
# User doesn't need to setup the tree-sitter packages if they are not running the test for that language.
|
||||
# from js_type_converter import js_type_converter
|
||||
# from java_type_converter import java_type_converter
|
||||
|
||||
PYTHON_TYPE_MAPPING = {
|
||||
"string": str,
|
||||
"integer": int,
|
||||
"float": float,
|
||||
"boolean": bool,
|
||||
"array": list,
|
||||
"tuple": list,
|
||||
"dict": dict,
|
||||
"any": str,
|
||||
}
|
||||
|
||||
# This is the list of types that we need to recursively check its values
|
||||
PYTHON_NESTED_TYPE_CHECK_LIST = ["array", "tuple"]
|
||||
|
||||
|
||||
NESTED_CONVERSION_TYPE_LIST = ["Array", "ArrayList", "array"]
|
||||
|
||||
|
||||
#### Helper functions for AST ####
|
||||
def find_description(func_descriptions, name):
|
||||
if type(func_descriptions) == list:
|
||||
for func_description in func_descriptions:
|
||||
if func_description["name"] == name:
|
||||
return func_description
|
||||
return None
|
||||
else:
|
||||
# it is a dict, there is only one function
|
||||
return func_descriptions
|
||||
|
||||
|
||||
def get_possible_answer_type(possible_answer: list):
|
||||
for answer in possible_answer:
|
||||
if answer != "": # Optional parameter
|
||||
return type(answer)
|
||||
return None
|
||||
|
||||
|
||||
def type_checker(
|
||||
param: str,
|
||||
value,
|
||||
possible_answer: list,
|
||||
expected_type_description: str,
|
||||
expected_type_converted,
|
||||
nested_type_converted,
|
||||
):
|
||||
# NOTE: This type checker only supports nested type checking for one level deep.
|
||||
# We didn't implement recursive type checking for nested types, as it's not needed for the current use case and it's very complex.
|
||||
|
||||
result: Any = {
|
||||
"valid": True,
|
||||
"error": [],
|
||||
"is_variable": False,
|
||||
"error_type": "type_error:simple",
|
||||
}
|
||||
|
||||
is_variable = False
|
||||
# check for the case where a variable is used instead of a actual value.
|
||||
# use the type in possible_answer as the expected type
|
||||
possible_answer_type = get_possible_answer_type(possible_answer)
|
||||
# if possible_answer only contains optional parameters, we can't determine the type
|
||||
if possible_answer_type != None:
|
||||
# we are being precise here.
|
||||
# in fact, possible_answer_type should always be string, as that's how we treat varibale in possible_answer
|
||||
if possible_answer_type != expected_type_converted:
|
||||
is_variable = True
|
||||
|
||||
# value is the same type as in function description
|
||||
if type(value) == expected_type_converted:
|
||||
# We don't need to do recursive check for simple types
|
||||
if nested_type_converted == None:
|
||||
result["is_variable"] = is_variable
|
||||
return result
|
||||
else:
|
||||
for possible_answer_item in possible_answer:
|
||||
flag = True # Each parameter should match to at least one possible answer type.
|
||||
# Here, we assume that each item should be the same type. We could also relax it.
|
||||
if type(possible_answer_item) == list:
|
||||
for value_item in value:
|
||||
checker_result = type_checker(
|
||||
param,
|
||||
value_item,
|
||||
possible_answer_item,
|
||||
str(nested_type_converted),
|
||||
nested_type_converted,
|
||||
None,
|
||||
)
|
||||
if not checker_result["valid"]:
|
||||
flag = False
|
||||
break
|
||||
|
||||
if flag:
|
||||
return {"valid": True, "error": [], "is_variable": is_variable}
|
||||
|
||||
result["valid"] = False
|
||||
result["error"] = [
|
||||
f"Nested type checking failed for parameter {repr(param)}. Expected outer type {expected_type_description} with inner type {str(nested_type_converted)}. Parameter value: {repr(value)}."
|
||||
]
|
||||
result["error_type"] = "type_error:nested"
|
||||
|
||||
# value is not as expected, check for the case where a variable is used instead of a actual value
|
||||
# use the type in possible_answer as the expected type
|
||||
possible_answer_type = get_possible_answer_type(possible_answer)
|
||||
# if possible_answer only contains optional parameters, we can't determine the type
|
||||
if possible_answer_type != None:
|
||||
# we are being precise here.
|
||||
# in fact, possible_answer_type should always be string, as that's how we treat varibale in possible_answer
|
||||
if type(value) == possible_answer_type:
|
||||
result["is_variable"] = True
|
||||
return result
|
||||
|
||||
result["valid"] = False
|
||||
result["error"].append(
|
||||
f"Incorrect type for parameter {repr(param)}. Expected type {expected_type_description}, got {type(value).__name__}. Parameter value: {repr(value)}."
|
||||
)
|
||||
result["error_type"] = "type_error:simple"
|
||||
return result
|
||||
|
||||
|
||||
def standardize_string(input_string: str):
|
||||
# This function standardizes the string by removing all the spaces, ",./-_*^" punctuation, and converting it to lowercase
|
||||
# It will also convert all the single quotes to double quotes
|
||||
# This is used to compare the model output with the possible answers
|
||||
# We don't want to punish model for answer like April 1, 2024 vs April 1,2024, vs April 1 2024
|
||||
regex_string = r"[ \,\.\/\-\_\*\^]"
|
||||
return re.sub(regex_string, "", input_string).lower().replace("'", '"')
|
||||
|
||||
|
||||
def string_checker(param: str, model_output: str, possible_answer: list):
|
||||
standardize_possible_answer = []
|
||||
standardize_model_output = standardize_string(model_output)
|
||||
for i in range(len(possible_answer)):
|
||||
if type(possible_answer[i]) == str:
|
||||
standardize_possible_answer.append(standardize_string(possible_answer[i]))
|
||||
|
||||
if standardize_model_output not in standardize_possible_answer:
|
||||
return {
|
||||
"valid": False,
|
||||
"error": [
|
||||
f"Invalid value for parameter {repr(param)}: {repr(model_output)}. Expected one of {possible_answer}. Case insensitive."
|
||||
],
|
||||
"error_type": "value_error:string",
|
||||
}
|
||||
|
||||
return {"valid": True, "error": []}
|
||||
|
||||
|
||||
def list_checker(param: str, model_output: list, possible_answer: list):
|
||||
# Convert the tuple to a list
|
||||
|
||||
standardize_model_output = list(model_output)
|
||||
|
||||
# If the element in the list is a string, we need to standardize it
|
||||
for i in range(len(standardize_model_output)):
|
||||
if type(standardize_model_output[i]) == str:
|
||||
standardize_model_output[i] = standardize_string(model_output[i])
|
||||
|
||||
standardize_possible_answer: Any = []
|
||||
# We also need to standardize the possible answers
|
||||
for i in range(len(possible_answer)):
|
||||
standardize_possible_answer.append([])
|
||||
for j in range(len(possible_answer[i])):
|
||||
if type(possible_answer[i][j]) == str:
|
||||
standardize_possible_answer[i].append(standardize_string(possible_answer[i][j]))
|
||||
else:
|
||||
standardize_possible_answer[i].append(possible_answer[i][j])
|
||||
|
||||
if standardize_model_output not in standardize_possible_answer:
|
||||
return {
|
||||
"valid": False,
|
||||
"error": [
|
||||
f"Invalid value for parameter {repr(param)}: {repr(model_output)}. Expected one of {possible_answer}."
|
||||
],
|
||||
"error_type": "value_error:list/tuple",
|
||||
}
|
||||
|
||||
return {"valid": True, "error": []}
|
||||
|
||||
|
||||
def dict_checker(param: str, model_output: dict, possible_answers: list):
|
||||
# This function works for simple dictionaries, but not dictionaries with nested dictionaries.
|
||||
# The current dataset only contains simple dictionaries, so this is sufficient.
|
||||
|
||||
result = {"valid": False, "error": [], "error_type": "dict_checker:unclear"}
|
||||
for i in range(len(possible_answers)):
|
||||
if possible_answers[i] == "":
|
||||
continue
|
||||
|
||||
result = {"valid": False, "error": [], "error_type": "dict_checker:unclear"}
|
||||
|
||||
flag = True
|
||||
|
||||
possible_answer = possible_answers[i]
|
||||
# possible_anwer is a single dictionary
|
||||
|
||||
for key, value in model_output.items():
|
||||
if key not in possible_answer:
|
||||
result["valid"] = False
|
||||
result["error"].append(f"Unexpected dict key parameter: '{key}'.") # type: ignore[attr-defined]
|
||||
result["error_type"] = "value_error:dict_key"
|
||||
flag = False
|
||||
break
|
||||
|
||||
standardize_value = value
|
||||
# If the value is a string, we need to standardize it
|
||||
if type(value) == str:
|
||||
standardize_value = standardize_string(value)
|
||||
|
||||
# We also need to standardize the possible answers if they are string
|
||||
standardize_possible_answer = []
|
||||
for i in range(len(possible_answer[key])):
|
||||
if type(possible_answer[key][i]) == str:
|
||||
standardize_possible_answer.append(standardize_string(possible_answer[key][i]))
|
||||
else:
|
||||
standardize_possible_answer.append(possible_answer[key][i])
|
||||
|
||||
if standardize_value not in standardize_possible_answer:
|
||||
result["valid"] = False
|
||||
result["error"].append( # type: ignore[attr-defined]
|
||||
f"Invalid value for parameter {repr(key)}: {repr(value)}. Expected one of {standardize_possible_answer}."
|
||||
)
|
||||
result["error_type"] = "value_error:dict_value"
|
||||
flag = False
|
||||
break
|
||||
|
||||
for key, value in possible_answer.items():
|
||||
if key not in model_output and "" not in value:
|
||||
result["valid"] = False
|
||||
result["error"].append(f"Missing dict key parameter: '{key}'.") # type: ignore[attr-defined]
|
||||
result["error_type"] = "value_error:dict_key"
|
||||
flag = False
|
||||
break
|
||||
|
||||
if flag:
|
||||
return {"valid": True, "error": []}
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def list_dict_checker(param: str, model_output: list, possible_answers: list):
|
||||
# This function takes in a list of dictionaries and checks if each dictionary is valid
|
||||
# The order of the dictionaries in the list must match the order of the possible answers
|
||||
|
||||
result = {"valid": False, "error": [], "error_type": "list_dict_checker:unclear"}
|
||||
|
||||
for answer_index in range(len(possible_answers)):
|
||||
flag = True # True means so far, all dictionaries are valid
|
||||
|
||||
# Only proceed if the number of dictionaries in the list matches the number of dictionaries in the possible answers
|
||||
if len(model_output) != len(possible_answers[answer_index]):
|
||||
result["valid"] = False
|
||||
result["error"] = ["Wrong number of dictionaries in the list."]
|
||||
result["error_type"] = "value_error:list_dict_count"
|
||||
flag = False
|
||||
continue
|
||||
|
||||
for dict_index in range(len(model_output)):
|
||||
result = dict_checker(
|
||||
param,
|
||||
model_output[dict_index],
|
||||
[possible_answers[answer_index][dict_index]],
|
||||
)
|
||||
if not result["valid"]:
|
||||
flag = False
|
||||
break
|
||||
if flag:
|
||||
return {"valid": True, "error": []}
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def simple_function_checker(
|
||||
func_description: dict,
|
||||
model_output: dict,
|
||||
possible_answer: dict,
|
||||
language: str,
|
||||
model_name: str,
|
||||
):
|
||||
possible_answer = list(possible_answer.values())[0]
|
||||
# Extract function name and parameters details
|
||||
func_name = func_description["name"]
|
||||
param_details = func_description["parameters"]["properties"]
|
||||
required_params = func_description["parameters"]["required"]
|
||||
|
||||
# Initialize a result dictionary
|
||||
result = {
|
||||
"valid": True,
|
||||
"error": [],
|
||||
"error_type": "simple_function_checker:unclear",
|
||||
}
|
||||
|
||||
# Check if function name matches
|
||||
if func_name not in model_output:
|
||||
result["valid"] = False
|
||||
result["error"].append( # type: ignore[attr-defined]
|
||||
f"Function name {repr(func_name)} not found in model output."
|
||||
)
|
||||
result["error_type"] = "simple_function_checker:wrong_func_name"
|
||||
return result
|
||||
|
||||
model_params = model_output[func_name]
|
||||
|
||||
# Check for required parameters in model output
|
||||
for param in required_params:
|
||||
if param not in model_params:
|
||||
result["valid"] = False
|
||||
result["error"].append(f"Missing required parameter: {repr(param)}.") # type: ignore[attr-defined]
|
||||
result["error_type"] = "simple_function_checker:missing_required"
|
||||
return result
|
||||
|
||||
# Validate types and values for each parameter in model output
|
||||
for param, value in model_params.items():
|
||||
if param not in param_details or param not in possible_answer:
|
||||
result["valid"] = False
|
||||
result["error"].append(f"Unexpected parameter: {repr(param)}.") # type: ignore[attr-defined]
|
||||
result["error_type"] = "simple_function_checker:unexpected_param"
|
||||
return result
|
||||
|
||||
full_param_details = param_details[param]
|
||||
expected_type_description = full_param_details["type"] # This is a string
|
||||
is_variable = False
|
||||
nested_type_converted = None
|
||||
|
||||
if language == "Java":
|
||||
from evals.utils.bfcl.java_type_converter import java_type_converter
|
||||
|
||||
expected_type_converted = JAVA_TYPE_CONVERSION[expected_type_description]
|
||||
|
||||
if expected_type_description in JAVA_TYPE_CONVERSION:
|
||||
if type(value) != str:
|
||||
result["valid"] = False
|
||||
result["error"].append( # type: ignore[attr-defined]
|
||||
f"Incorrect type for parameter {repr(param)}. Expected type String, got {type(value).__name__}. Parameter value: {repr(value)}."
|
||||
)
|
||||
result["error_type"] = "type_error:java"
|
||||
return result
|
||||
|
||||
if expected_type_description in NESTED_CONVERSION_TYPE_LIST:
|
||||
nested_type = param_details[param]["items"]["type"]
|
||||
nested_type_converted = JAVA_TYPE_CONVERSION[nested_type]
|
||||
value = java_type_converter(value, expected_type_description, nested_type)
|
||||
else:
|
||||
value = java_type_converter(value, expected_type_description)
|
||||
|
||||
elif language == "JavaScript":
|
||||
from evals.utils.bfcl.js_type_converter import js_type_converter
|
||||
|
||||
expected_type_converted = JS_TYPE_CONVERSION[expected_type_description]
|
||||
|
||||
if expected_type_description in JS_TYPE_CONVERSION:
|
||||
if type(value) != str:
|
||||
result["valid"] = False
|
||||
result["error"].append( # type: ignore[attr-defined]
|
||||
f"Incorrect type for parameter {repr(param)}. Expected type String, got {type(value).__name__}. Parameter value: {repr(value)}."
|
||||
)
|
||||
result["error_type"] = "type_error:js"
|
||||
return result
|
||||
|
||||
if expected_type_description in NESTED_CONVERSION_TYPE_LIST:
|
||||
nested_type = param_details[param]["items"]["type"]
|
||||
nested_type_converted = JS_TYPE_CONVERSION[nested_type]
|
||||
value = js_type_converter(value, expected_type_description, nested_type)
|
||||
else:
|
||||
value = js_type_converter(value, expected_type_description)
|
||||
|
||||
elif language == "Python":
|
||||
expected_type_converted = PYTHON_TYPE_MAPPING[expected_type_description]
|
||||
if expected_type_description in PYTHON_NESTED_TYPE_CHECK_LIST:
|
||||
nested_type = param_details[param]["items"]["type"]
|
||||
nested_type_converted = PYTHON_TYPE_MAPPING[nested_type]
|
||||
|
||||
# We convert all tuple value to list when the expected type is tuple.
|
||||
# The conversion is necessary because any tuple in the possible answer would become a list after being processed through json.dump() and json.load().
|
||||
# This does introduce some false positive (eg, when the model provides a list value instead of tuple). We hope to find a better solution in the future.
|
||||
if expected_type_description == "tuple" and type(value) == tuple:
|
||||
value = list(value)
|
||||
|
||||
# Allow python auto conversion from int to float
|
||||
if language == "Python" and expected_type_description == "float" and type(value) == int:
|
||||
value = float(value)
|
||||
|
||||
# Type checking
|
||||
# In fact, we only check for Python here.
|
||||
# Type check for other languages are handled by the type converter, and so their value (after conversion) is always correct.
|
||||
type_check_result = type_checker(
|
||||
param,
|
||||
value,
|
||||
possible_answer[param],
|
||||
expected_type_description,
|
||||
expected_type_converted,
|
||||
nested_type_converted,
|
||||
)
|
||||
is_variable = type_check_result["is_variable"]
|
||||
if not type_check_result["valid"]:
|
||||
return type_check_result
|
||||
|
||||
# It doesn't make sense to special handle dictionaries and list of dictionaries if the value is a variable.
|
||||
# We can just treat the variable as a string and use the normal flow.
|
||||
if not is_variable:
|
||||
# Special handle for dictionaries
|
||||
if expected_type_converted == dict:
|
||||
result = dict_checker(param, value, possible_answer[param])
|
||||
if not result["valid"]:
|
||||
return result
|
||||
continue
|
||||
|
||||
# Special handle for list of dictionaries
|
||||
elif expected_type_converted == list and nested_type_converted == dict:
|
||||
result = list_dict_checker(param, value, possible_answer[param])
|
||||
if not result["valid"]:
|
||||
return result
|
||||
continue
|
||||
|
||||
# Special handle for strings
|
||||
elif expected_type_converted == str:
|
||||
# We don't check for case sensitivity for string, as long as it's not a variable
|
||||
result = string_checker(param, value, possible_answer[param])
|
||||
if not result["valid"]:
|
||||
return result
|
||||
continue
|
||||
|
||||
elif expected_type_converted == list:
|
||||
result = list_checker(param, value, possible_answer[param])
|
||||
if not result["valid"]:
|
||||
return result
|
||||
continue
|
||||
|
||||
# Check if the value is within the possible answers
|
||||
if value not in possible_answer[param]:
|
||||
result["valid"] = False
|
||||
result["error"].append( # type: ignore[attr-defined]
|
||||
f"Invalid value for parameter {repr(param)}: {repr(value)}. Expected one of {possible_answer[param]}."
|
||||
)
|
||||
result["error_type"] = "value_error:others"
|
||||
return result
|
||||
|
||||
# Check for optional parameters not provided but allowed
|
||||
for param in possible_answer:
|
||||
if param not in model_params and "" not in possible_answer[param]:
|
||||
result["valid"] = False
|
||||
result["error"].append( # type: ignore[attr-defined]
|
||||
f"Optional parameter {repr(param)} not provided and not marked as optional."
|
||||
)
|
||||
result["error_type"] = "simple_function_checker:missing_optional"
|
||||
return result
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def parallel_function_checker_enforce_order(
|
||||
func_descriptions: list,
|
||||
model_output: list,
|
||||
possible_answers: dict,
|
||||
language: str,
|
||||
model_name: str,
|
||||
):
|
||||
if len(model_output) != len(possible_answers):
|
||||
return {
|
||||
"valid": False,
|
||||
"error": ["Wrong number of functions."],
|
||||
"error_type": "parallel_function_checker_enforce_order:wrong_count",
|
||||
}
|
||||
|
||||
func_name_list = list(possible_answers.keys())
|
||||
possible_answers_list = []
|
||||
|
||||
for key, value in possible_answers.items():
|
||||
possible_answers_list.append({key: value})
|
||||
|
||||
for i in range(len(possible_answers_list)):
|
||||
func_description = find_description(func_descriptions, func_name_list[i])
|
||||
|
||||
result = simple_function_checker(
|
||||
func_description,
|
||||
model_output[i],
|
||||
possible_answers_list[i],
|
||||
language,
|
||||
model_name,
|
||||
)
|
||||
if not result["valid"]:
|
||||
return result
|
||||
|
||||
return {"valid": True, "error": []}
|
||||
|
||||
|
||||
def parallel_function_checker_no_order(
|
||||
func_descriptions: list,
|
||||
model_output: list,
|
||||
possible_answers: list,
|
||||
language: str,
|
||||
model_name: str,
|
||||
):
|
||||
if len(model_output) != len(possible_answers):
|
||||
return {
|
||||
"valid": False,
|
||||
"error": ["Wrong number of functions."],
|
||||
"error_type": "parallel_function_checker_no_order:wrong_count",
|
||||
}
|
||||
|
||||
matched_indices = []
|
||||
|
||||
# We go throught the possible answers one by one, and eliminate the model output that matches the possible answer
|
||||
# It must be this way because we need ground truth to fetch the correct function description
|
||||
for i in range(len(possible_answers)):
|
||||
# possible_answers[i] is a dictionary with only one key
|
||||
func_name_expected = list(possible_answers[i].keys())[0]
|
||||
func_description = find_description(func_descriptions, func_name_expected)
|
||||
|
||||
all_errors = []
|
||||
|
||||
for index in range(len(model_output)):
|
||||
if index in matched_indices:
|
||||
continue
|
||||
|
||||
result = simple_function_checker(
|
||||
func_description,
|
||||
model_output[index],
|
||||
possible_answers[i],
|
||||
language,
|
||||
model_name,
|
||||
)
|
||||
|
||||
if result["valid"]:
|
||||
matched_indices.append(index)
|
||||
break
|
||||
else:
|
||||
all_errors.append(
|
||||
{
|
||||
f"Model Result Index {index}": {
|
||||
"sub_error": result["error"],
|
||||
"sub_error_type": result["error_type"],
|
||||
"model_output_item": model_output[index],
|
||||
"possible_answer_item": possible_answers[i],
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
if not result["valid"]:
|
||||
considered_indices = [i for i in range(len(model_output)) if i not in matched_indices]
|
||||
all_errors.insert(
|
||||
0,
|
||||
f"Could not find a matching function among index {considered_indices} of model output for index {i} of possible answers.", # type: ignore[arg-type]
|
||||
)
|
||||
return {
|
||||
"valid": False,
|
||||
"error": all_errors,
|
||||
"error_type": "parallel_function_checker_no_order:cannot_find_match",
|
||||
}
|
||||
|
||||
return {"valid": True, "error": []}
|
||||
|
||||
|
||||
def multiple_function_checker(
|
||||
func_descriptions: list,
|
||||
model_output: list,
|
||||
possible_answers: list,
|
||||
language: str,
|
||||
model_name: str,
|
||||
):
|
||||
if len(model_output) != len(possible_answers):
|
||||
return {
|
||||
"valid": False,
|
||||
"error": ["Wrong number of functions."],
|
||||
"error_type": "multiple_function_checker:wrong_count",
|
||||
}
|
||||
|
||||
# possible_answers is a list of only one dictionary with only one key
|
||||
func_name_expected = list(possible_answers[0].keys())[0]
|
||||
func_description = find_description(func_descriptions, func_name_expected)
|
||||
return simple_function_checker(
|
||||
func_description,
|
||||
model_output[0],
|
||||
possible_answers[0],
|
||||
language,
|
||||
model_name,
|
||||
)
|
||||
|
||||
|
||||
def patten_matcher(exec_output, expected_result, function_call, is_sanity_check):
|
||||
result = {"valid": True, "error": [], "error_type": "executable_checker:unclear"}
|
||||
|
||||
if type(exec_output) != type(expected_result):
|
||||
return {
|
||||
"valid": False,
|
||||
"error": [
|
||||
f"Wrong execution result type for {repr(function_call)}. Expected type: {type(expected_result)}, but got: {type(exec_output)}."
|
||||
],
|
||||
"error_type": "executable_checker:wrong_result_type",
|
||||
"model_executed_output": exec_output,
|
||||
}
|
||||
if type(exec_output) == dict:
|
||||
# We loose the requirement for the sanity check as the expected result used in the sanity check might not be the most up-to-date one.
|
||||
# This happens when the key is a timestamp or a random number.
|
||||
if is_sanity_check:
|
||||
if len(exec_output) != len(expected_result):
|
||||
return {
|
||||
"valid": False,
|
||||
"error": [
|
||||
f"Wrong execution result pattern for {repr(function_call)}. Expect type Dict, but wrong number of elements in the output. Expected length: {len(expected_result)}, but got: {len(exec_output)}."
|
||||
],
|
||||
"error_type": "executable_checker:wrong_result_type:dict_length",
|
||||
"model_executed_output": exec_output,
|
||||
}
|
||||
else:
|
||||
return result
|
||||
|
||||
for key, value in expected_result.items():
|
||||
if key not in exec_output:
|
||||
return {
|
||||
"valid": False,
|
||||
"error": [
|
||||
f"Wrong execution result pattern for {repr(function_call)}. Expect type Dict, but key {repr(key)} not found in the model output."
|
||||
],
|
||||
"error_type": "executable_checker:wrong_result_type:dict_key_not_found",
|
||||
"model_executed_output": exec_output,
|
||||
}
|
||||
for key, value in exec_output.items():
|
||||
if key not in expected_result:
|
||||
return {
|
||||
"valid": False,
|
||||
"error": [
|
||||
f"Wrong execution result pattern for {repr(function_call)}. Expect type Dict, but key {repr(key)} not expected in the model output."
|
||||
],
|
||||
"error_type": "executable_checker:wrong_result_type:dict_extra_key",
|
||||
"model_executed_output": exec_output,
|
||||
}
|
||||
if type(exec_output) == list:
|
||||
if len(exec_output) != len(expected_result):
|
||||
return {
|
||||
"valid": False,
|
||||
"error": [
|
||||
f"Wrong execution result pattern for {repr(function_call)}. Expect type list, but wrong number of elements in the output. Expected length: {len(expected_result)}, but got: {len(exec_output)}."
|
||||
],
|
||||
"error_type": "executable_checker:wrong_result_type:list_length",
|
||||
"model_executed_output": exec_output,
|
||||
}
|
||||
return result
|
||||
|
||||
|
||||
#### Helper functions for Exec ####
|
||||
def executable_checker_simple(
|
||||
function_call: str,
|
||||
expected_result,
|
||||
expected_result_type: str,
|
||||
is_sanity_check=False,
|
||||
):
|
||||
result = {"valid": True, "error": [], "error_type": "executable_checker:unclear"}
|
||||
|
||||
exec_dict: Any = {}
|
||||
|
||||
try:
|
||||
exec(
|
||||
"from executable_python_function import *" + "\nresult=" + function_call,
|
||||
exec_dict,
|
||||
)
|
||||
exec_output = exec_dict["result"]
|
||||
except NoAPIKeyError as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
result["valid"] = False
|
||||
result["error"].append( # type: ignore[attr-defined]
|
||||
f"Error in execution: {repr(function_call)}. Error: {str(e)}"
|
||||
)
|
||||
result["error_type"] = "executable_checker:execution_error"
|
||||
return result
|
||||
|
||||
# We need to special handle the case where the execution result is a tuple and convert it to a list
|
||||
# Because when json is stored, the tuple is converted to a list, and so the expected result is a list when loaded from json
|
||||
if isinstance(exec_output, tuple):
|
||||
exec_output = list(exec_output)
|
||||
|
||||
if expected_result_type == "exact_match":
|
||||
if exec_output != expected_result:
|
||||
result["valid"] = False
|
||||
result["error"].append( # type: ignore[attr-defined]
|
||||
f"Wrong execution result for {repr(function_call)}. Expected: {expected_result}, but got: {exec_output}."
|
||||
)
|
||||
result["error_type"] = "executable_checker:wrong_result"
|
||||
result["model_executed_output"] = exec_output
|
||||
return result
|
||||
|
||||
elif expected_result_type == "real_time_match":
|
||||
# Allow for 5% difference
|
||||
if (type(expected_result) == float or type(expected_result) == int) and (
|
||||
type(exec_output) == float or type(exec_output) == int
|
||||
):
|
||||
if not (
|
||||
expected_result * (1 - REAL_TIME_MATCH_ALLOWED_DIFFERENCE)
|
||||
<= exec_output
|
||||
<= expected_result * (1 + REAL_TIME_MATCH_ALLOWED_DIFFERENCE)
|
||||
):
|
||||
result["valid"] = False
|
||||
result["error"].append( # type: ignore[attr-defined]
|
||||
f"Wrong execution result for {repr(function_call)}. Expected: {expected_result}, but got: {exec_output}. {REAL_TIME_MATCH_ALLOWED_DIFFERENCE * 100}% difference allowed."
|
||||
)
|
||||
result["error_type"] = "executable_checker:wrong_result_real_time"
|
||||
result["model_executed_output"] = exec_output
|
||||
return result
|
||||
else:
|
||||
result["valid"] = False
|
||||
result["error"].append( # type: ignore[attr-defined]
|
||||
f"Wrong execution result for {repr(function_call)}. Expected: {expected_result}, but got: {exec_output}. Type needs to be float or int for real time match criteria."
|
||||
)
|
||||
result["error_type"] = "executable_checker:wrong_result_real_time"
|
||||
result["model_executed_output"] = exec_output
|
||||
return result
|
||||
|
||||
else:
|
||||
# structural match
|
||||
pattern_match_result = patten_matcher(exec_output, expected_result, function_call, is_sanity_check)
|
||||
if not pattern_match_result["valid"]:
|
||||
return pattern_match_result
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def executable_checker_parallel_no_order(
|
||||
decoded_result: list, expected_exec_result: list, expected_exec_result_type: list
|
||||
):
|
||||
if len(decoded_result) != len(expected_exec_result):
|
||||
return {
|
||||
"valid": False,
|
||||
"error": [
|
||||
f"Wrong number of functions provided. Expected {len(expected_exec_result)}, but got {len(decoded_result)}."
|
||||
],
|
||||
"error_type": "value_error:exec_result_count",
|
||||
}
|
||||
|
||||
matched_indices = []
|
||||
for i in range(len(expected_exec_result)):
|
||||
all_errors = []
|
||||
for index in range(len(decoded_result)):
|
||||
if index in matched_indices:
|
||||
continue
|
||||
|
||||
result = executable_checker_simple(
|
||||
decoded_result[index],
|
||||
expected_exec_result[i],
|
||||
expected_exec_result_type[i],
|
||||
False,
|
||||
)
|
||||
|
||||
if result["valid"]:
|
||||
matched_indices.append(index)
|
||||
break
|
||||
else:
|
||||
all_errors.append(
|
||||
{
|
||||
f"Model Result Index {index}": {
|
||||
"sub_error": result["error"],
|
||||
"sub_error_type": result["error_type"],
|
||||
"model_executed_output": (
|
||||
result["model_executed_output"] if "model_executed_output" in result else None
|
||||
),
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
if not result["valid"]:
|
||||
considered_indices = [i for i in range(len(decoded_result)) if i not in matched_indices]
|
||||
all_errors.insert(
|
||||
0,
|
||||
f"Could not find a matching function among index {considered_indices} of model output for index {i} of possible answers.", # type: ignore[arg-type]
|
||||
)
|
||||
return {
|
||||
"valid": False,
|
||||
"error": all_errors,
|
||||
"error_type": "executable_checker:cannot_find_match",
|
||||
}
|
||||
|
||||
return {"valid": True, "error": [], "error_type": "executable_checker:unclear"}
|
||||
|
||||
|
||||
#### Main function ####
|
||||
def executable_checker_rest(func_call, idx):
|
||||
# Move this here for now to avoid needing to read this file / fix paths to be relative to dataset_dir. Fix when it's actually needed / used.
|
||||
EVAL_GROUND_TRUTH_PATH = "/mnt/wsfuse/fair_llm_v2/datasets/eval/bfcl/rest-eval-response_v5.jsonl" # Ground truth file for v5 for rest execution
|
||||
with open(EVAL_GROUND_TRUTH_PATH, "r") as f:
|
||||
EVAL_GROUND_TRUTH = f.readlines()
|
||||
if "https://geocode.maps.co" in func_call:
|
||||
time.sleep(2)
|
||||
if "requests_get" in func_call:
|
||||
func_call = func_call.replace("requests_get", "requests.get")
|
||||
try:
|
||||
response = eval(func_call)
|
||||
except Exception as e:
|
||||
return {
|
||||
"valid": False,
|
||||
"error": [f"Execution failed. {str(e)}"],
|
||||
"error_type": "executable_checker_rest:execution_error",
|
||||
}
|
||||
|
||||
try:
|
||||
if response.status_code == 200:
|
||||
eval_GT_json = json.loads(EVAL_GROUND_TRUTH[idx])
|
||||
try:
|
||||
if isinstance(eval_GT_json, dict):
|
||||
if isinstance(response.json(), dict):
|
||||
if set(eval_GT_json.keys()) == set(response.json().keys()):
|
||||
return {"valid": True, "error": [], "error_type": ""}
|
||||
return {
|
||||
"valid": False,
|
||||
"error": ["Key inconsistency"],
|
||||
"error_type": "executable_checker_rest:wrong_key",
|
||||
}
|
||||
return {
|
||||
"valid": False,
|
||||
"error": [f"Expected dictionary, but got {type(response.json())}"],
|
||||
"error_type": "executable_checker_rest:wrong_type",
|
||||
}
|
||||
|
||||
elif isinstance(eval_GT_json, list):
|
||||
if isinstance(response.json(), list):
|
||||
if len(eval_GT_json) != len(response.json()):
|
||||
return {
|
||||
"valid": False,
|
||||
"error": [f"Response list length inconsistency."],
|
||||
"error_type": "value_error:exec_result_rest_count",
|
||||
}
|
||||
|
||||
else:
|
||||
for i in range(len(eval_GT_json)):
|
||||
if set(eval_GT_json[i].keys()) != set(response.json()[i].keys()):
|
||||
return {
|
||||
"valid": False,
|
||||
"error": [f"Key inconsistency"],
|
||||
"error_type": "executable_checker_rest:wrong_key",
|
||||
}
|
||||
|
||||
return {"valid": True, "error": []}
|
||||
else:
|
||||
return {
|
||||
"valid": False,
|
||||
"error": [f"Expected list, but got {type(response.json())}"],
|
||||
"error_type": "executable_checker_rest:wrong_type",
|
||||
}
|
||||
return {
|
||||
"valid": False,
|
||||
"error": [f"Expected dict or list, but got {type(response.json())}"],
|
||||
"error_type": "executable_checker_rest:wrong_type",
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"valid": False,
|
||||
"error": [
|
||||
f"Error in execution and type checking. Status code: {response.status_code}. Error: {str(e)}"
|
||||
],
|
||||
"error_type": "executable_checker_rest:response_format_error",
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"valid": False,
|
||||
"error": [f"Execution result status code is not 200, got {response.status_code}"],
|
||||
"error_type": "executable_checker_rest:wrong_status_code",
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"valid": False,
|
||||
"error": [f"Cannot get status code of the response. Error: {str(e)}"],
|
||||
"error_type": "executable_checker_rest:cannot_get_status_code",
|
||||
}
|
||||
|
||||
|
||||
def ast_checker(func_description, model_output, possible_answer, language, test_category, model_name):
|
||||
if "parallel" in test_category:
|
||||
return parallel_function_checker_no_order(func_description, model_output, possible_answer, language, model_name)
|
||||
|
||||
elif "multiple" in test_category:
|
||||
return multiple_function_checker(func_description, model_output, possible_answer, language, model_name)
|
||||
|
||||
else:
|
||||
if len(model_output) != 1:
|
||||
return {
|
||||
"valid": False,
|
||||
"error": ["Wrong number of functions."],
|
||||
"error_type": "simple_function_checker:wrong_count",
|
||||
}
|
||||
|
||||
return simple_function_checker(
|
||||
func_description[0],
|
||||
model_output[0],
|
||||
possible_answer[0],
|
||||
language,
|
||||
model_name,
|
||||
)
|
||||
|
||||
|
||||
def exec_checker(decoded_result: list, func_description: dict, test_category: str):
|
||||
if "multiple" in test_category or "parallel" in test_category:
|
||||
return executable_checker_parallel_no_order(
|
||||
decoded_result,
|
||||
func_description["execution_result"],
|
||||
func_description["execution_result_type"],
|
||||
)
|
||||
|
||||
else:
|
||||
if len(decoded_result) != 1:
|
||||
return {
|
||||
"valid": False,
|
||||
"error": ["Wrong number of functions."],
|
||||
"error_type": "simple_exec_checker:wrong_count",
|
||||
}
|
||||
return executable_checker_simple(
|
||||
decoded_result[0],
|
||||
func_description["execution_result"][0],
|
||||
func_description["execution_result_type"][0],
|
||||
False,
|
||||
)
|
||||
|
||||
|
||||
def is_empty_output(decoded_output):
|
||||
# This function is a patch to the ast decoder for relevance detection
|
||||
# Sometimes the ast decoder will parse successfully, but the input doens't really have a function call
|
||||
# [], [{}], and anything that is not in function calling format is considered empty (and thus should be marked as correct)
|
||||
if not is_function_calling_format_output(decoded_output):
|
||||
return True
|
||||
if len(decoded_output) == 0:
|
||||
return True
|
||||
if len(decoded_output) == 1 and len(decoded_output[0]) == 0:
|
||||
return True
|
||||
|
||||
|
||||
def is_function_calling_format_output(decoded_output):
|
||||
# Ensure the output is a list of dictionaries
|
||||
if type(decoded_output) == list:
|
||||
for item in decoded_output:
|
||||
if type(item) != dict:
|
||||
return False
|
||||
return True
|
||||
return False
|
|
@ -1,40 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
"""
|
||||
Tree-sitter changes its API with unfortunate frequency. Modules that need it should
|
||||
import it from here so that we can centrally manage things as necessary.
|
||||
"""
|
||||
|
||||
# These currently work with tree-sitter 0.23.0
|
||||
# NOTE: Don't import tree-sitter or any of the language modules in the main module
|
||||
# because not all environments have them. Import lazily inside functions where needed.
|
||||
|
||||
import importlib
|
||||
import typing
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
import tree_sitter
|
||||
|
||||
|
||||
def get_language(language: str) -> "tree_sitter.Language":
|
||||
import tree_sitter
|
||||
|
||||
language_module_name = f"tree_sitter_{language}"
|
||||
try:
|
||||
language_module = importlib.import_module(language_module_name)
|
||||
except ModuleNotFoundError as exc:
|
||||
raise ValueError(
|
||||
f"Language {language} is not found. Please install the tree-sitter-{language} package."
|
||||
) from exc
|
||||
return tree_sitter.Language(language_module.language())
|
||||
|
||||
|
||||
def get_parser(language: str, **kwargs) -> "tree_sitter.Parser":
|
||||
import tree_sitter
|
||||
|
||||
lang = get_language(language)
|
||||
return tree_sitter.Parser(lang, **kwargs)
|
|
@ -30,11 +30,11 @@ from llama_stack.providers.utils.kvstore.api import KVStore
|
|||
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
|
||||
from llama_stack.providers.utils.memory.vector_store import (
|
||||
RERANKER_TYPE_RRF,
|
||||
RERANKER_TYPE_WEIGHTED,
|
||||
ChunkForDeletion,
|
||||
EmbeddingIndex,
|
||||
VectorDBWithIndex,
|
||||
)
|
||||
from llama_stack.providers.utils.vector_io.vector_utils import WeightedInMemoryAggregator
|
||||
|
||||
logger = get_logger(name=__name__, category="vector_io")
|
||||
|
||||
|
@ -66,59 +66,6 @@ def _create_sqlite_connection(db_path):
|
|||
return connection
|
||||
|
||||
|
||||
def _normalize_scores(scores: dict[str, float]) -> dict[str, float]:
|
||||
"""Normalize scores to [0,1] range using min-max normalization."""
|
||||
if not scores:
|
||||
return {}
|
||||
min_score = min(scores.values())
|
||||
max_score = max(scores.values())
|
||||
score_range = max_score - min_score
|
||||
if score_range > 0:
|
||||
return {doc_id: (score - min_score) / score_range for doc_id, score in scores.items()}
|
||||
return dict.fromkeys(scores, 1.0)
|
||||
|
||||
|
||||
def _weighted_rerank(
|
||||
vector_scores: dict[str, float],
|
||||
keyword_scores: dict[str, float],
|
||||
alpha: float = 0.5,
|
||||
) -> dict[str, float]:
|
||||
"""ReRanker that uses weighted average of scores."""
|
||||
all_ids = set(vector_scores.keys()) | set(keyword_scores.keys())
|
||||
normalized_vector_scores = _normalize_scores(vector_scores)
|
||||
normalized_keyword_scores = _normalize_scores(keyword_scores)
|
||||
|
||||
return {
|
||||
doc_id: (alpha * normalized_keyword_scores.get(doc_id, 0.0))
|
||||
+ ((1 - alpha) * normalized_vector_scores.get(doc_id, 0.0))
|
||||
for doc_id in all_ids
|
||||
}
|
||||
|
||||
|
||||
def _rrf_rerank(
|
||||
vector_scores: dict[str, float],
|
||||
keyword_scores: dict[str, float],
|
||||
impact_factor: float = 60.0,
|
||||
) -> dict[str, float]:
|
||||
"""ReRanker that uses Reciprocal Rank Fusion."""
|
||||
# Convert scores to ranks
|
||||
vector_ranks = {
|
||||
doc_id: i + 1 for i, (doc_id, _) in enumerate(sorted(vector_scores.items(), key=lambda x: x[1], reverse=True))
|
||||
}
|
||||
keyword_ranks = {
|
||||
doc_id: i + 1 for i, (doc_id, _) in enumerate(sorted(keyword_scores.items(), key=lambda x: x[1], reverse=True))
|
||||
}
|
||||
|
||||
all_ids = set(vector_scores.keys()) | set(keyword_scores.keys())
|
||||
rrf_scores = {}
|
||||
for doc_id in all_ids:
|
||||
vector_rank = vector_ranks.get(doc_id, float("inf"))
|
||||
keyword_rank = keyword_ranks.get(doc_id, float("inf"))
|
||||
# RRF formula: score = 1/(k + r) where k is impact_factor and r is the rank
|
||||
rrf_scores[doc_id] = (1.0 / (impact_factor + vector_rank)) + (1.0 / (impact_factor + keyword_rank))
|
||||
return rrf_scores
|
||||
|
||||
|
||||
def _make_sql_identifier(name: str) -> str:
|
||||
return re.sub(r"[^a-zA-Z0-9_]", "_", name)
|
||||
|
||||
|
@ -398,14 +345,10 @@ class SQLiteVecIndex(EmbeddingIndex):
|
|||
for chunk, score in zip(keyword_response.chunks, keyword_response.scores, strict=False)
|
||||
}
|
||||
|
||||
# Combine scores using the specified reranker
|
||||
if reranker_type == RERANKER_TYPE_WEIGHTED:
|
||||
alpha = reranker_params.get("alpha", 0.5)
|
||||
combined_scores = _weighted_rerank(vector_scores, keyword_scores, alpha)
|
||||
else:
|
||||
# Default to RRF for None, RRF, or any unknown types
|
||||
impact_factor = reranker_params.get("impact_factor", 60.0)
|
||||
combined_scores = _rrf_rerank(vector_scores, keyword_scores, impact_factor)
|
||||
# Combine scores using the reranking utility
|
||||
combined_scores = WeightedInMemoryAggregator.combine_search_results(
|
||||
vector_scores, keyword_scores, reranker_type, reranker_params
|
||||
)
|
||||
|
||||
# Sort by combined score and get top k results
|
||||
sorted_items = sorted(combined_scores.items(), key=lambda x: x[1], reverse=True)
|
||||
|
|
|
@ -40,8 +40,9 @@ def available_providers() -> list[ProviderSpec]:
|
|||
InlineProviderSpec(
|
||||
api=Api.inference,
|
||||
provider_type="inline::sentence-transformers",
|
||||
# CrossEncoder depends on torchao.quantization
|
||||
pip_packages=[
|
||||
"torch torchvision --index-url https://download.pytorch.org/whl/cpu",
|
||||
"torch torchvision torchao>=0.12.0 --extra-index-url https://download.pytorch.org/whl/cpu",
|
||||
"sentence-transformers --no-deps",
|
||||
],
|
||||
module="llama_stack.providers.inline.inference.sentence_transformers",
|
||||
|
@ -115,7 +116,7 @@ def available_providers() -> list[ProviderSpec]:
|
|||
adapter=AdapterSpec(
|
||||
adapter_type="fireworks",
|
||||
pip_packages=[
|
||||
"fireworks-ai",
|
||||
"fireworks-ai<=0.18.0",
|
||||
],
|
||||
module="llama_stack.providers.remote.inference.fireworks",
|
||||
config_class="llama_stack.providers.remote.inference.fireworks.FireworksImplConfig",
|
||||
|
@ -291,7 +292,7 @@ Available Models:
|
|||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="watsonx",
|
||||
pip_packages=["ibm_watson_machine_learning"],
|
||||
pip_packages=["ibm_watsonx_ai"],
|
||||
module="llama_stack.providers.remote.inference.watsonx",
|
||||
config_class="llama_stack.providers.remote.inference.watsonx.WatsonXConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.watsonx.WatsonXProviderDataValidator",
|
||||
|
|
|
@ -13,7 +13,7 @@ from llama_stack.providers.datatypes import AdapterSpec, Api, InlineProviderSpec
|
|||
# The CPU version is used for distributions that don't have GPU support -- they result in smaller container images.
|
||||
torchtune_def = dict(
|
||||
api=Api.post_training,
|
||||
pip_packages=["torchtune==0.5.0", "torchao==0.8.0", "numpy"],
|
||||
pip_packages=["numpy"],
|
||||
module="llama_stack.providers.inline.post_training.torchtune",
|
||||
config_class="llama_stack.providers.inline.post_training.torchtune.TorchtunePostTrainingConfig",
|
||||
api_dependencies=[
|
||||
|
@ -23,56 +23,39 @@ torchtune_def = dict(
|
|||
description="TorchTune-based post-training provider for fine-tuning and optimizing models using Meta's TorchTune framework.",
|
||||
)
|
||||
|
||||
huggingface_def = dict(
|
||||
api=Api.post_training,
|
||||
pip_packages=["trl", "transformers", "peft", "datasets"],
|
||||
module="llama_stack.providers.inline.post_training.huggingface",
|
||||
config_class="llama_stack.providers.inline.post_training.huggingface.HuggingFacePostTrainingConfig",
|
||||
api_dependencies=[
|
||||
Api.datasetio,
|
||||
Api.datasets,
|
||||
],
|
||||
description="HuggingFace-based post-training provider for fine-tuning models using the HuggingFace ecosystem.",
|
||||
)
|
||||
|
||||
|
||||
def available_providers() -> list[ProviderSpec]:
|
||||
return [
|
||||
InlineProviderSpec(
|
||||
**{
|
||||
**{ # type: ignore
|
||||
**torchtune_def,
|
||||
"provider_type": "inline::torchtune-cpu",
|
||||
"pip_packages": (
|
||||
cast(list[str], torchtune_def["pip_packages"])
|
||||
+ ["torch torchtune==0.5.0 torchao==0.8.0 --index-url https://download.pytorch.org/whl/cpu"]
|
||||
+ ["torch torchtune>=0.5.0 torchao>=0.12.0 --extra-index-url https://download.pytorch.org/whl/cpu"]
|
||||
),
|
||||
},
|
||||
),
|
||||
InlineProviderSpec(
|
||||
**{
|
||||
**huggingface_def,
|
||||
"provider_type": "inline::huggingface-cpu",
|
||||
"pip_packages": (
|
||||
cast(list[str], huggingface_def["pip_packages"])
|
||||
+ ["torch --index-url https://download.pytorch.org/whl/cpu"]
|
||||
),
|
||||
},
|
||||
),
|
||||
InlineProviderSpec(
|
||||
**{
|
||||
**{ # type: ignore
|
||||
**torchtune_def,
|
||||
"provider_type": "inline::torchtune-gpu",
|
||||
"pip_packages": (
|
||||
cast(list[str], torchtune_def["pip_packages"]) + ["torch torchtune==0.5.0 torchao==0.8.0"]
|
||||
cast(list[str], torchtune_def["pip_packages"]) + ["torch torchtune>=0.5.0 torchao>=0.12.0"]
|
||||
),
|
||||
},
|
||||
),
|
||||
InlineProviderSpec(
|
||||
**{
|
||||
**huggingface_def,
|
||||
"provider_type": "inline::huggingface-gpu",
|
||||
"pip_packages": (cast(list[str], huggingface_def["pip_packages"]) + ["torch"]),
|
||||
},
|
||||
api=Api.post_training,
|
||||
provider_type="inline::huggingface-gpu",
|
||||
pip_packages=["trl", "transformers", "peft", "datasets", "torch"],
|
||||
module="llama_stack.providers.inline.post_training.huggingface",
|
||||
config_class="llama_stack.providers.inline.post_training.huggingface.HuggingFacePostTrainingConfig",
|
||||
api_dependencies=[
|
||||
Api.datasetio,
|
||||
Api.datasets,
|
||||
],
|
||||
description="HuggingFace-based post-training provider for fine-tuning models using the HuggingFace ecosystem.",
|
||||
),
|
||||
remote_provider_spec(
|
||||
api=Api.post_training,
|
||||
|
|
|
@ -404,6 +404,60 @@ That means you'll get fast and efficient vector retrieval.
|
|||
- Easy to use
|
||||
- Fully integrated with Llama Stack
|
||||
|
||||
There are three implementations of search for PGVectoIndex available:
|
||||
|
||||
1. Vector Search:
|
||||
- How it works:
|
||||
- Uses PostgreSQL's vector extension (pgvector) to perform similarity search
|
||||
- Compares query embeddings against stored embeddings using Cosine distance or other distance metrics
|
||||
- Eg. SQL query: SELECT document, embedding <=> %s::vector AS distance FROM table ORDER BY distance
|
||||
|
||||
-Characteristics:
|
||||
- Semantic understanding - finds documents similar in meaning even if they don't share keywords
|
||||
- Works with high-dimensional vector embeddings (typically 768, 1024, or higher dimensions)
|
||||
- Best for: Finding conceptually related content, handling synonyms, cross-language search
|
||||
|
||||
2. Keyword Search
|
||||
- How it works:
|
||||
- Uses PostgreSQL's full-text search capabilities with tsvector and ts_rank
|
||||
- Converts text to searchable tokens using to_tsvector('english', text). Default language is English.
|
||||
- Eg. SQL query: SELECT document, ts_rank(tokenized_content, plainto_tsquery('english', %s)) AS score
|
||||
|
||||
- Characteristics:
|
||||
- Lexical matching - finds exact keyword matches and variations
|
||||
- Uses GIN (Generalized Inverted Index) for fast text search performance
|
||||
- Scoring: Uses PostgreSQL's ts_rank function for relevance scoring
|
||||
- Best for: Exact term matching, proper names, technical terms, Boolean-style queries
|
||||
|
||||
3. Hybrid Search
|
||||
- How it works:
|
||||
- Combines both vector and keyword search results
|
||||
- Runs both searches independently, then merges results using configurable reranking
|
||||
|
||||
- Two reranking strategies available:
|
||||
- Reciprocal Rank Fusion (RRF) - (default: 60.0)
|
||||
- Weighted Average - (default: 0.5)
|
||||
|
||||
- Characteristics:
|
||||
- Best of both worlds: semantic understanding + exact matching
|
||||
- Documents appearing in both searches get boosted scores
|
||||
- Configurable balance between semantic and lexical matching
|
||||
- Best for: General-purpose search where you want both precision and recall
|
||||
|
||||
4. Database Schema
|
||||
The PGVector implementation stores data optimized for all three search types:
|
||||
CREATE TABLE vector_store_xxx (
|
||||
id TEXT PRIMARY KEY,
|
||||
document JSONB, -- Original document
|
||||
embedding vector(dimension), -- For vector search
|
||||
content_text TEXT, -- Raw text content
|
||||
tokenized_content TSVECTOR -- For keyword search
|
||||
);
|
||||
|
||||
-- Indexes for performance
|
||||
CREATE INDEX content_gin_idx ON table USING GIN(tokenized_content); -- Keyword search
|
||||
-- Vector index created automatically by pgvector
|
||||
|
||||
## Usage
|
||||
|
||||
To use PGVector in your Llama Stack project, follow these steps:
|
||||
|
@ -412,6 +466,25 @@ To use PGVector in your Llama Stack project, follow these steps:
|
|||
2. Configure your Llama Stack project to use pgvector. (e.g. remote::pgvector).
|
||||
3. Start storing and querying vectors.
|
||||
|
||||
## This is an example how you can set up your environment for using PGVector
|
||||
|
||||
1. Export env vars:
|
||||
```bash
|
||||
export ENABLE_PGVECTOR=true
|
||||
export PGVECTOR_HOST=localhost
|
||||
export PGVECTOR_PORT=5432
|
||||
export PGVECTOR_DB=llamastack
|
||||
export PGVECTOR_USER=llamastack
|
||||
export PGVECTOR_PASSWORD=llamastack
|
||||
```
|
||||
|
||||
2. Create DB:
|
||||
```bash
|
||||
psql -h localhost -U postgres -c "CREATE ROLE llamastack LOGIN PASSWORD 'llamastack';"
|
||||
psql -h localhost -U postgres -c "CREATE DATABASE llamastack OWNER llamastack;"
|
||||
psql -h localhost -U llamastack -d llamastack -c "CREATE EXTENSION IF NOT EXISTS vector;"
|
||||
```
|
||||
|
||||
## Installation
|
||||
|
||||
You can install PGVector using docker:
|
||||
|
@ -449,6 +522,7 @@ Weaviate supports:
|
|||
- Metadata filtering
|
||||
- Multi-modal retrieval
|
||||
|
||||
|
||||
## Usage
|
||||
|
||||
To use Weaviate in your Llama Stack project, follow these steps:
|
||||
|
|
|
@ -6,15 +6,14 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.core.datatypes import Api
|
||||
from llama_stack.core.datatypes import AccessRule, Api
|
||||
|
||||
from .config import S3FilesImplConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: S3FilesImplConfig, deps: dict[Api, Any]):
|
||||
async def get_adapter_impl(config: S3FilesImplConfig, deps: dict[Api, Any], policy: list[AccessRule] | None = None):
|
||||
from .files import S3FilesImpl
|
||||
|
||||
# TODO: authorization policies and user separation
|
||||
impl = S3FilesImpl(config)
|
||||
impl = S3FilesImpl(config, policy or [])
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
@ -4,9 +4,9 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import time
|
||||
import uuid
|
||||
from typing import Annotated
|
||||
from datetime import UTC, datetime
|
||||
from typing import Annotated, Any
|
||||
|
||||
import boto3
|
||||
from botocore.exceptions import BotoCoreError, ClientError, NoCredentialsError
|
||||
|
@ -15,14 +15,17 @@ from fastapi import File, Form, Response, UploadFile
|
|||
from llama_stack.apis.common.errors import ResourceNotFoundError
|
||||
from llama_stack.apis.common.responses import Order
|
||||
from llama_stack.apis.files import (
|
||||
ExpiresAfter,
|
||||
Files,
|
||||
ListOpenAIFileResponse,
|
||||
OpenAIFileDeleteResponse,
|
||||
OpenAIFileObject,
|
||||
OpenAIFilePurpose,
|
||||
)
|
||||
from llama_stack.core.datatypes import AccessRule
|
||||
from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import SqlStore, sqlstore_impl
|
||||
from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import sqlstore_impl
|
||||
|
||||
from .config import S3FilesImplConfig
|
||||
|
||||
|
@ -83,22 +86,85 @@ async def _create_bucket_if_not_exists(client: boto3.client, config: S3FilesImpl
|
|||
raise RuntimeError(f"Failed to access S3 bucket '{config.bucket_name}': {e}") from e
|
||||
|
||||
|
||||
def _make_file_object(
|
||||
*,
|
||||
id: str,
|
||||
filename: str,
|
||||
purpose: str,
|
||||
bytes: int,
|
||||
created_at: int,
|
||||
expires_at: int,
|
||||
**kwargs: Any, # here to ignore any additional fields, e.g. extra fields from AuthorizedSqlStore
|
||||
) -> OpenAIFileObject:
|
||||
"""
|
||||
Construct an OpenAIFileObject and normalize expires_at.
|
||||
|
||||
If expires_at is greater than the max we treat it as no-expiration and
|
||||
return None for expires_at.
|
||||
|
||||
The OpenAI spec says expires_at type is Integer, but the implementation
|
||||
will return None for no expiration.
|
||||
"""
|
||||
obj = OpenAIFileObject(
|
||||
id=id,
|
||||
filename=filename,
|
||||
purpose=OpenAIFilePurpose(purpose),
|
||||
bytes=bytes,
|
||||
created_at=created_at,
|
||||
expires_at=expires_at,
|
||||
)
|
||||
|
||||
if obj.expires_at is not None and obj.expires_at > (obj.created_at + ExpiresAfter.MAX):
|
||||
obj.expires_at = None # type: ignore
|
||||
|
||||
return obj
|
||||
|
||||
|
||||
class S3FilesImpl(Files):
|
||||
"""S3-based implementation of the Files API."""
|
||||
|
||||
# TODO: implement expiration, for now a silly offset
|
||||
_SILLY_EXPIRATION_OFFSET = 100 * 365 * 24 * 60 * 60
|
||||
|
||||
def __init__(self, config: S3FilesImplConfig) -> None:
|
||||
def __init__(self, config: S3FilesImplConfig, policy: list[AccessRule]) -> None:
|
||||
self._config = config
|
||||
self.policy = policy
|
||||
self._client: boto3.client | None = None
|
||||
self._sql_store: SqlStore | None = None
|
||||
self._sql_store: AuthorizedSqlStore | None = None
|
||||
|
||||
def _now(self) -> int:
|
||||
"""Return current UTC timestamp as int seconds."""
|
||||
return int(datetime.now(UTC).timestamp())
|
||||
|
||||
async def _get_file(self, file_id: str, return_expired: bool = False) -> dict[str, Any]:
|
||||
where: dict[str, str | dict] = {"id": file_id}
|
||||
if not return_expired:
|
||||
where["expires_at"] = {">": self._now()}
|
||||
if not (row := await self.sql_store.fetch_one("openai_files", policy=self.policy, where=where)):
|
||||
raise ResourceNotFoundError(file_id, "File", "files.list()")
|
||||
return row
|
||||
|
||||
async def _delete_file(self, file_id: str) -> None:
|
||||
"""Delete a file from S3 and the database."""
|
||||
try:
|
||||
self.client.delete_object(
|
||||
Bucket=self._config.bucket_name,
|
||||
Key=file_id,
|
||||
)
|
||||
except ClientError as e:
|
||||
if e.response["Error"]["Code"] != "NoSuchKey":
|
||||
raise RuntimeError(f"Failed to delete file from S3: {e}") from e
|
||||
|
||||
await self.sql_store.delete("openai_files", where={"id": file_id})
|
||||
|
||||
async def _delete_if_expired(self, file_id: str) -> None:
|
||||
"""If the file exists and is expired, delete it."""
|
||||
if row := await self._get_file(file_id, return_expired=True):
|
||||
if (expires_at := row.get("expires_at")) and expires_at <= self._now():
|
||||
await self._delete_file(file_id)
|
||||
|
||||
async def initialize(self) -> None:
|
||||
self._client = _create_s3_client(self._config)
|
||||
await _create_bucket_if_not_exists(self._client, self._config)
|
||||
|
||||
self._sql_store = sqlstore_impl(self._config.metadata_store)
|
||||
self._sql_store = AuthorizedSqlStore(sqlstore_impl(self._config.metadata_store))
|
||||
await self._sql_store.create_table(
|
||||
"openai_files",
|
||||
{
|
||||
|
@ -121,7 +187,7 @@ class S3FilesImpl(Files):
|
|||
return self._client
|
||||
|
||||
@property
|
||||
def sql_store(self) -> SqlStore:
|
||||
def sql_store(self) -> AuthorizedSqlStore:
|
||||
assert self._sql_store is not None, "Provider not initialized"
|
||||
return self._sql_store
|
||||
|
||||
|
@ -129,27 +195,47 @@ class S3FilesImpl(Files):
|
|||
self,
|
||||
file: Annotated[UploadFile, File()],
|
||||
purpose: Annotated[OpenAIFilePurpose, Form()],
|
||||
expires_after_anchor: Annotated[str | None, Form(alias="expires_after[anchor]")] = None,
|
||||
expires_after_seconds: Annotated[int | None, Form(alias="expires_after[seconds]")] = None,
|
||||
) -> OpenAIFileObject:
|
||||
file_id = f"file-{uuid.uuid4().hex}"
|
||||
|
||||
filename = getattr(file, "filename", None) or "uploaded_file"
|
||||
|
||||
created_at = int(time.time())
|
||||
expires_at = created_at + self._SILLY_EXPIRATION_OFFSET
|
||||
created_at = self._now()
|
||||
|
||||
expires_after = None
|
||||
if expires_after_anchor is not None or expires_after_seconds is not None:
|
||||
# we use ExpiresAfter to validate input
|
||||
expires_after = ExpiresAfter(
|
||||
anchor=expires_after_anchor, # type: ignore[arg-type]
|
||||
seconds=expires_after_seconds, # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
# the default is no expiration.
|
||||
# to implement no expiration we set an expiration beyond the max.
|
||||
# we'll hide this fact from users when returning the file object.
|
||||
expires_at = created_at + ExpiresAfter.MAX * 42
|
||||
# the default for BATCH files is 30 days, which happens to be the expiration max.
|
||||
if purpose == OpenAIFilePurpose.BATCH:
|
||||
expires_at = created_at + ExpiresAfter.MAX
|
||||
|
||||
if expires_after is not None:
|
||||
expires_at = created_at + expires_after.seconds
|
||||
|
||||
content = await file.read()
|
||||
file_size = len(content)
|
||||
|
||||
await self.sql_store.insert(
|
||||
"openai_files",
|
||||
{
|
||||
"id": file_id,
|
||||
"filename": filename,
|
||||
"purpose": purpose.value,
|
||||
"bytes": file_size,
|
||||
"created_at": created_at,
|
||||
"expires_at": expires_at,
|
||||
},
|
||||
)
|
||||
entry: dict[str, Any] = {
|
||||
"id": file_id,
|
||||
"filename": filename,
|
||||
"purpose": purpose.value,
|
||||
"bytes": file_size,
|
||||
"created_at": created_at,
|
||||
"expires_at": expires_at,
|
||||
}
|
||||
|
||||
await self.sql_store.insert("openai_files", entry)
|
||||
|
||||
try:
|
||||
self.client.put_object(
|
||||
|
@ -163,14 +249,7 @@ class S3FilesImpl(Files):
|
|||
|
||||
raise RuntimeError(f"Failed to upload file to S3: {e}") from e
|
||||
|
||||
return OpenAIFileObject(
|
||||
id=file_id,
|
||||
filename=filename,
|
||||
purpose=purpose,
|
||||
bytes=file_size,
|
||||
created_at=created_at,
|
||||
expires_at=expires_at,
|
||||
)
|
||||
return _make_file_object(**entry)
|
||||
|
||||
async def openai_list_files(
|
||||
self,
|
||||
|
@ -183,29 +262,20 @@ class S3FilesImpl(Files):
|
|||
if not order:
|
||||
order = Order.desc
|
||||
|
||||
where_conditions = {}
|
||||
where_conditions: dict[str, Any] = {"expires_at": {">": self._now()}}
|
||||
if purpose:
|
||||
where_conditions["purpose"] = purpose.value
|
||||
|
||||
paginated_result = await self.sql_store.fetch_all(
|
||||
table="openai_files",
|
||||
where=where_conditions if where_conditions else None,
|
||||
policy=self.policy,
|
||||
where=where_conditions,
|
||||
order_by=[("created_at", order.value)],
|
||||
cursor=("id", after) if after else None,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
files = [
|
||||
OpenAIFileObject(
|
||||
id=row["id"],
|
||||
filename=row["filename"],
|
||||
purpose=OpenAIFilePurpose(row["purpose"]),
|
||||
bytes=row["bytes"],
|
||||
created_at=row["created_at"],
|
||||
expires_at=row["expires_at"],
|
||||
)
|
||||
for row in paginated_result.data
|
||||
]
|
||||
files = [_make_file_object(**row) for row in paginated_result.data]
|
||||
|
||||
return ListOpenAIFileResponse(
|
||||
data=files,
|
||||
|
@ -216,41 +286,20 @@ class S3FilesImpl(Files):
|
|||
)
|
||||
|
||||
async def openai_retrieve_file(self, file_id: str) -> OpenAIFileObject:
|
||||
row = await self.sql_store.fetch_one("openai_files", where={"id": file_id})
|
||||
if not row:
|
||||
raise ResourceNotFoundError(file_id, "File", "files.list()")
|
||||
|
||||
return OpenAIFileObject(
|
||||
id=row["id"],
|
||||
filename=row["filename"],
|
||||
purpose=OpenAIFilePurpose(row["purpose"]),
|
||||
bytes=row["bytes"],
|
||||
created_at=row["created_at"],
|
||||
expires_at=row["expires_at"],
|
||||
)
|
||||
await self._delete_if_expired(file_id)
|
||||
row = await self._get_file(file_id)
|
||||
return _make_file_object(**row)
|
||||
|
||||
async def openai_delete_file(self, file_id: str) -> OpenAIFileDeleteResponse:
|
||||
row = await self.sql_store.fetch_one("openai_files", where={"id": file_id})
|
||||
if not row:
|
||||
raise ResourceNotFoundError(file_id, "File", "files.list()")
|
||||
|
||||
try:
|
||||
self.client.delete_object(
|
||||
Bucket=self._config.bucket_name,
|
||||
Key=row["id"],
|
||||
)
|
||||
except ClientError as e:
|
||||
if e.response["Error"]["Code"] != "NoSuchKey":
|
||||
raise RuntimeError(f"Failed to delete file from S3: {e}") from e
|
||||
|
||||
await self.sql_store.delete("openai_files", where={"id": file_id})
|
||||
|
||||
await self._delete_if_expired(file_id)
|
||||
_ = await self._get_file(file_id) # raises if not found
|
||||
await self._delete_file(file_id)
|
||||
return OpenAIFileDeleteResponse(id=file_id, deleted=True)
|
||||
|
||||
async def openai_retrieve_file_content(self, file_id: str) -> Response:
|
||||
row = await self.sql_store.fetch_one("openai_files", where={"id": file_id})
|
||||
if not row:
|
||||
raise ResourceNotFoundError(file_id, "File", "files.list()")
|
||||
await self._delete_if_expired(file_id)
|
||||
|
||||
row = await self._get_file(file_id)
|
||||
|
||||
try:
|
||||
response = self.client.get_object(
|
||||
|
@ -261,7 +310,7 @@ class S3FilesImpl(Files):
|
|||
content = response["Body"].read()
|
||||
except ClientError as e:
|
||||
if e.response["Error"]["Code"] == "NoSuchKey":
|
||||
await self.sql_store.delete("openai_files", where={"id": file_id})
|
||||
await self._delete_file(file_id)
|
||||
raise ResourceNotFoundError(file_id, "File", "files.list()") from e
|
||||
raise RuntimeError(f"Failed to download file from S3: {e}") from e
|
||||
|
||||
|
|
|
@ -41,10 +41,10 @@ client.initialize()
|
|||
|
||||
### Create Completion
|
||||
|
||||
> Note on Completion API
|
||||
>
|
||||
> The hosted NVIDIA Llama NIMs (e.g., `meta-llama/Llama-3.1-8B-Instruct`) with ```NVIDIA_BASE_URL="https://integrate.api.nvidia.com"``` does not support the ```completion``` method, while the locally deployed NIM does.
|
||||
The following example shows how to create a completion for an NVIDIA NIM.
|
||||
|
||||
> [!NOTE]
|
||||
> The hosted NVIDIA Llama NIMs (for example ```meta-llama/Llama-3.1-8B-Instruct```) that have ```NVIDIA_BASE_URL="https://integrate.api.nvidia.com"``` do not support the ```completion``` method, while locally deployed NIMs do.
|
||||
|
||||
```python
|
||||
response = client.inference.completion(
|
||||
|
@ -60,6 +60,8 @@ print(f"Response: {response.content}")
|
|||
|
||||
### Create Chat Completion
|
||||
|
||||
The following example shows how to create a chat completion for an NVIDIA NIM.
|
||||
|
||||
```python
|
||||
response = client.inference.chat_completion(
|
||||
model_id="meta-llama/Llama-3.1-8B-Instruct",
|
||||
|
@ -82,6 +84,9 @@ print(f"Response: {response.completion_message.content}")
|
|||
```
|
||||
|
||||
### Tool Calling Example ###
|
||||
|
||||
The following example shows how to do tool calling for an NVIDIA NIM.
|
||||
|
||||
```python
|
||||
from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition
|
||||
|
||||
|
@ -117,6 +122,9 @@ if tool_response.completion_message.tool_calls:
|
|||
```
|
||||
|
||||
### Structured Output Example
|
||||
|
||||
The following example shows how to do structured output for an NVIDIA NIM.
|
||||
|
||||
```python
|
||||
from llama_stack.apis.inference import JsonSchemaResponseFormat, ResponseFormatType
|
||||
|
||||
|
@ -149,8 +157,10 @@ print(f"Structured Response: {structured_response.completion_message.content}")
|
|||
```
|
||||
|
||||
### Create Embeddings
|
||||
> Note on OpenAI embeddings compatibility
|
||||
>
|
||||
|
||||
The following example shows how to create embeddings for an NVIDIA NIM.
|
||||
|
||||
> [!NOTE]
|
||||
> NVIDIA asymmetric embedding models (e.g., `nvidia/llama-3.2-nv-embedqa-1b-v2`) require an `input_type` parameter not present in the standard OpenAI embeddings API. The NVIDIA Inference Adapter automatically sets `input_type="query"` when using the OpenAI-compatible embeddings endpoint for NVIDIA. For passage embeddings, use the `embeddings` API with `task_type="document"`.
|
||||
|
||||
```python
|
||||
|
@ -160,4 +170,42 @@ response = client.inference.embeddings(
|
|||
task_type="query",
|
||||
)
|
||||
print(f"Embeddings: {response.embeddings}")
|
||||
```
|
||||
```
|
||||
|
||||
### Vision Language Models Example
|
||||
|
||||
The following example shows how to run vision inference by using an NVIDIA NIM.
|
||||
|
||||
```python
|
||||
def load_image_as_base64(image_path):
|
||||
with open(image_path, "rb") as image_file:
|
||||
img_bytes = image_file.read()
|
||||
return base64.b64encode(img_bytes).decode("utf-8")
|
||||
|
||||
|
||||
image_path = {path_to_the_image}
|
||||
demo_image_b64 = load_image_as_base64(image_path)
|
||||
|
||||
vlm_response = client.inference.chat_completion(
|
||||
model_id="nvidia/vila",
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image",
|
||||
"image": {
|
||||
"data": demo_image_b64,
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "text",
|
||||
"text": "Please describe what you see in this image in detail.",
|
||||
},
|
||||
],
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
print(f"VLM Response: {vlm_response.completion_message.content}")
|
||||
```
|
||||
|
|
|
@ -55,6 +55,10 @@ MODEL_ENTRIES = [
|
|||
"meta/llama-3.3-70b-instruct",
|
||||
CoreModelId.llama3_3_70b_instruct.value,
|
||||
),
|
||||
ProviderModelEntry(
|
||||
provider_model_id="nvidia/vila",
|
||||
model_type=ModelType.llm,
|
||||
),
|
||||
# NeMo Retriever Text Embedding models -
|
||||
#
|
||||
# https://docs.nvidia.com/nim/nemo-retriever/text-embedding/latest/support-matrix.html
|
||||
|
|
|
@ -118,10 +118,10 @@ class OllamaInferenceAdapter(
|
|||
|
||||
async def initialize(self) -> None:
|
||||
logger.info(f"checking connectivity to Ollama at `{self.config.url}`...")
|
||||
health_response = await self.health()
|
||||
if health_response["status"] == HealthStatus.ERROR:
|
||||
r = await self.health()
|
||||
if r["status"] == HealthStatus.ERROR:
|
||||
logger.warning(
|
||||
"Ollama Server is not running, make sure to start it using `ollama serve` in a separate terminal"
|
||||
f"Ollama Server is not running (message: {r['message']}). Make sure to start it using `ollama serve` in a separate terminal"
|
||||
)
|
||||
|
||||
async def should_refresh_models(self) -> bool:
|
||||
|
@ -156,7 +156,7 @@ class OllamaInferenceAdapter(
|
|||
),
|
||||
Model(
|
||||
identifier="nomic-embed-text",
|
||||
provider_resource_id="nomic-embed-text",
|
||||
provider_resource_id="nomic-embed-text:latest",
|
||||
provider_id=provider_id,
|
||||
metadata={
|
||||
"embedding_dimension": 768,
|
||||
|
|
|
@ -7,8 +7,8 @@
|
|||
from collections.abc import AsyncGenerator, AsyncIterator
|
||||
from typing import Any
|
||||
|
||||
from ibm_watson_machine_learning.foundation_models import Model
|
||||
from ibm_watson_machine_learning.metanames import GenTextParamsMetaNames as GenParams
|
||||
from ibm_watsonx_ai.foundation_models import Model
|
||||
from ibm_watsonx_ai.metanames import GenTextParamsMetaNames as GenParams
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem
|
||||
|
|
|
@ -4,6 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import heapq
|
||||
from typing import Any
|
||||
|
||||
import psycopg2
|
||||
|
@ -23,6 +24,9 @@ from llama_stack.apis.vector_io import (
|
|||
)
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
interleaved_content_as_str,
|
||||
)
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
from llama_stack.providers.utils.kvstore.api import KVStore
|
||||
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
|
||||
|
@ -31,6 +35,7 @@ from llama_stack.providers.utils.memory.vector_store import (
|
|||
EmbeddingIndex,
|
||||
VectorDBWithIndex,
|
||||
)
|
||||
from llama_stack.providers.utils.vector_io.vector_utils import WeightedInMemoryAggregator, sanitize_collection_name
|
||||
|
||||
from .config import PGVectorVectorIOConfig
|
||||
|
||||
|
@ -72,25 +77,63 @@ def load_models(cur, cls):
|
|||
|
||||
|
||||
class PGVectorIndex(EmbeddingIndex):
|
||||
def __init__(self, vector_db: VectorDB, dimension: int, conn, kvstore: KVStore | None = None):
|
||||
self.conn = conn
|
||||
with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
|
||||
# Sanitize the table name by replacing hyphens with underscores
|
||||
# SQL doesn't allow hyphens in table names, and vector_db.identifier may contain hyphens
|
||||
# when created with patterns like "test-vector-db-{uuid4()}"
|
||||
sanitized_identifier = vector_db.identifier.replace("-", "_")
|
||||
self.table_name = f"vector_store_{sanitized_identifier}"
|
||||
self.kvstore = kvstore
|
||||
# reference: https://github.com/pgvector/pgvector?tab=readme-ov-file#querying
|
||||
PGVECTOR_DISTANCE_METRIC_TO_SEARCH_FUNCTION: dict[str, str] = {
|
||||
"L2": "<->",
|
||||
"L1": "<+>",
|
||||
"COSINE": "<=>",
|
||||
"INNER_PRODUCT": "<#>",
|
||||
"HAMMING": "<~>",
|
||||
"JACCARD": "<%>",
|
||||
}
|
||||
|
||||
cur.execute(
|
||||
f"""
|
||||
CREATE TABLE IF NOT EXISTS {self.table_name} (
|
||||
id TEXT PRIMARY KEY,
|
||||
document JSONB,
|
||||
embedding vector({dimension})
|
||||
def __init__(
|
||||
self,
|
||||
vector_db: VectorDB,
|
||||
dimension: int,
|
||||
conn: psycopg2.extensions.connection,
|
||||
kvstore: KVStore | None = None,
|
||||
distance_metric: str = "COSINE",
|
||||
):
|
||||
self.vector_db = vector_db
|
||||
self.dimension = dimension
|
||||
self.conn = conn
|
||||
self.kvstore = kvstore
|
||||
self.check_distance_metric_availability(distance_metric)
|
||||
self.distance_metric = distance_metric
|
||||
self.table_name = None
|
||||
|
||||
async def initialize(self) -> None:
|
||||
try:
|
||||
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
|
||||
# Sanitize the table name by replacing hyphens with underscores
|
||||
# SQL doesn't allow hyphens in table names, and vector_db.identifier may contain hyphens
|
||||
# when created with patterns like "test-vector-db-{uuid4()}"
|
||||
sanitized_identifier = sanitize_collection_name(self.vector_db.identifier)
|
||||
self.table_name = f"vs_{sanitized_identifier}"
|
||||
|
||||
cur.execute(
|
||||
f"""
|
||||
CREATE TABLE IF NOT EXISTS {self.table_name} (
|
||||
id TEXT PRIMARY KEY,
|
||||
document JSONB,
|
||||
embedding vector({self.dimension}),
|
||||
content_text TEXT,
|
||||
tokenized_content TSVECTOR
|
||||
)
|
||||
"""
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
# Create GIN index for full-text search performance
|
||||
cur.execute(
|
||||
f"""
|
||||
CREATE INDEX IF NOT EXISTS {self.table_name}_content_gin_idx
|
||||
ON {self.table_name} USING GIN(tokenized_content)
|
||||
"""
|
||||
)
|
||||
except Exception as e:
|
||||
log.exception(f"Error creating PGVectorIndex for vector_db: {self.vector_db.identifier}")
|
||||
raise RuntimeError(f"Error creating PGVectorIndex for vector_db: {self.vector_db.identifier}") from e
|
||||
|
||||
async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray):
|
||||
assert len(chunks) == len(embeddings), (
|
||||
|
@ -99,29 +142,49 @@ class PGVectorIndex(EmbeddingIndex):
|
|||
|
||||
values = []
|
||||
for i, chunk in enumerate(chunks):
|
||||
content_text = interleaved_content_as_str(chunk.content)
|
||||
values.append(
|
||||
(
|
||||
f"{chunk.chunk_id}",
|
||||
Json(chunk.model_dump()),
|
||||
embeddings[i].tolist(),
|
||||
content_text,
|
||||
content_text, # Pass content_text twice - once for content_text column, once for to_tsvector function. Eg. to_tsvector(content_text) = tokenized_content
|
||||
)
|
||||
)
|
||||
|
||||
query = sql.SQL(
|
||||
f"""
|
||||
INSERT INTO {self.table_name} (id, document, embedding)
|
||||
INSERT INTO {self.table_name} (id, document, embedding, content_text, tokenized_content)
|
||||
VALUES %s
|
||||
ON CONFLICT (id) DO UPDATE SET embedding = EXCLUDED.embedding, document = EXCLUDED.document
|
||||
ON CONFLICT (id) DO UPDATE SET
|
||||
embedding = EXCLUDED.embedding,
|
||||
document = EXCLUDED.document,
|
||||
content_text = EXCLUDED.content_text,
|
||||
tokenized_content = EXCLUDED.tokenized_content
|
||||
"""
|
||||
)
|
||||
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
|
||||
execute_values(cur, query, values, template="(%s, %s, %s::vector)")
|
||||
execute_values(cur, query, values, template="(%s, %s, %s::vector, %s, to_tsvector('english', %s))")
|
||||
|
||||
async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
|
||||
"""
|
||||
Performs vector similarity search using PostgreSQL's search function. Default distance metric is COSINE.
|
||||
|
||||
Args:
|
||||
embedding: The query embedding vector
|
||||
k: Number of results to return
|
||||
score_threshold: Minimum similarity score threshold
|
||||
|
||||
Returns:
|
||||
QueryChunksResponse with combined results
|
||||
"""
|
||||
pgvector_search_function = self.get_pgvector_search_function()
|
||||
|
||||
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
|
||||
cur.execute(
|
||||
f"""
|
||||
SELECT document, embedding <-> %s::vector AS distance
|
||||
SELECT document, embedding {pgvector_search_function} %s::vector AS distance
|
||||
FROM {self.table_name}
|
||||
ORDER BY distance
|
||||
LIMIT %s
|
||||
|
@ -147,7 +210,40 @@ class PGVectorIndex(EmbeddingIndex):
|
|||
k: int,
|
||||
score_threshold: float,
|
||||
) -> QueryChunksResponse:
|
||||
raise NotImplementedError("Keyword search is not supported in PGVector")
|
||||
"""
|
||||
Performs keyword-based search using PostgreSQL's full-text search with ts_rank scoring.
|
||||
|
||||
Args:
|
||||
query_string: The text query for keyword search
|
||||
k: Number of results to return
|
||||
score_threshold: Minimum similarity score threshold
|
||||
|
||||
Returns:
|
||||
QueryChunksResponse with combined results
|
||||
"""
|
||||
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
|
||||
# Use plainto_tsquery to handle user input safely and ts_rank for relevance scoring
|
||||
cur.execute(
|
||||
f"""
|
||||
SELECT document, ts_rank(tokenized_content, plainto_tsquery('english', %s)) AS score
|
||||
FROM {self.table_name}
|
||||
WHERE tokenized_content @@ plainto_tsquery('english', %s)
|
||||
ORDER BY score DESC
|
||||
LIMIT %s
|
||||
""",
|
||||
(query_string, query_string, k),
|
||||
)
|
||||
results = cur.fetchall()
|
||||
|
||||
chunks = []
|
||||
scores = []
|
||||
for doc, score in results:
|
||||
if score < score_threshold:
|
||||
continue
|
||||
chunks.append(Chunk(**doc))
|
||||
scores.append(float(score))
|
||||
|
||||
return QueryChunksResponse(chunks=chunks, scores=scores)
|
||||
|
||||
async def query_hybrid(
|
||||
self,
|
||||
|
@ -158,7 +254,59 @@ class PGVectorIndex(EmbeddingIndex):
|
|||
reranker_type: str,
|
||||
reranker_params: dict[str, Any] | None = None,
|
||||
) -> QueryChunksResponse:
|
||||
raise NotImplementedError("Hybrid search is not supported in PGVector")
|
||||
"""
|
||||
Hybrid search combining vector similarity and keyword search using configurable reranking.
|
||||
|
||||
Args:
|
||||
embedding: The query embedding vector
|
||||
query_string: The text query for keyword search
|
||||
k: Number of results to return
|
||||
score_threshold: Minimum similarity score threshold
|
||||
reranker_type: Type of reranker to use ("rrf" or "weighted")
|
||||
reranker_params: Parameters for the reranker
|
||||
|
||||
Returns:
|
||||
QueryChunksResponse with combined results
|
||||
"""
|
||||
if reranker_params is None:
|
||||
reranker_params = {}
|
||||
|
||||
# Get results from both search methods
|
||||
vector_response = await self.query_vector(embedding, k, score_threshold)
|
||||
keyword_response = await self.query_keyword(query_string, k, score_threshold)
|
||||
|
||||
# Convert responses to score dictionaries using chunk_id
|
||||
vector_scores = {
|
||||
chunk.chunk_id: score for chunk, score in zip(vector_response.chunks, vector_response.scores, strict=False)
|
||||
}
|
||||
keyword_scores = {
|
||||
chunk.chunk_id: score
|
||||
for chunk, score in zip(keyword_response.chunks, keyword_response.scores, strict=False)
|
||||
}
|
||||
|
||||
# Combine scores using the reranking utility
|
||||
combined_scores = WeightedInMemoryAggregator.combine_search_results(
|
||||
vector_scores, keyword_scores, reranker_type, reranker_params
|
||||
)
|
||||
|
||||
# Efficient top-k selection because it only tracks the k best candidates it's seen so far
|
||||
top_k_items = heapq.nlargest(k, combined_scores.items(), key=lambda x: x[1])
|
||||
|
||||
# Filter by score threshold
|
||||
filtered_items = [(doc_id, score) for doc_id, score in top_k_items if score >= score_threshold]
|
||||
|
||||
# Create a map of chunk_id to chunk for both responses
|
||||
chunk_map = {c.chunk_id: c for c in vector_response.chunks + keyword_response.chunks}
|
||||
|
||||
# Use the map to look up chunks by their IDs
|
||||
chunks = []
|
||||
scores = []
|
||||
for doc_id, score in filtered_items:
|
||||
if doc_id in chunk_map:
|
||||
chunks.append(chunk_map[doc_id])
|
||||
scores.append(score)
|
||||
|
||||
return QueryChunksResponse(chunks=chunks, scores=scores)
|
||||
|
||||
async def delete(self):
|
||||
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
|
||||
|
@ -170,6 +318,25 @@ class PGVectorIndex(EmbeddingIndex):
|
|||
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
|
||||
cur.execute(f"DELETE FROM {self.table_name} WHERE id = ANY(%s)", (chunk_ids,))
|
||||
|
||||
def get_pgvector_search_function(self) -> str:
|
||||
return self.PGVECTOR_DISTANCE_METRIC_TO_SEARCH_FUNCTION[self.distance_metric]
|
||||
|
||||
def check_distance_metric_availability(self, distance_metric: str) -> None:
|
||||
"""Check if the distance metric is supported by PGVector.
|
||||
|
||||
Args:
|
||||
distance_metric: The distance metric to check
|
||||
|
||||
Raises:
|
||||
ValueError: If the distance metric is not supported
|
||||
"""
|
||||
if distance_metric not in self.PGVECTOR_DISTANCE_METRIC_TO_SEARCH_FUNCTION:
|
||||
supported_metrics = list(self.PGVECTOR_DISTANCE_METRIC_TO_SEARCH_FUNCTION.keys())
|
||||
raise ValueError(
|
||||
f"Distance metric '{distance_metric}' is not supported by PGVector. "
|
||||
f"Supported metrics are: {', '.join(supported_metrics)}"
|
||||
)
|
||||
|
||||
|
||||
class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
|
||||
def __init__(
|
||||
|
@ -185,8 +352,8 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoco
|
|||
self.files_api = files_api
|
||||
self.kvstore: KVStore | None = None
|
||||
self.vector_db_store = None
|
||||
self.openai_vector_store: dict[str, dict[str, Any]] = {}
|
||||
self.metadatadata_collection_name = "openai_vector_stores_metadata"
|
||||
self.openai_vector_stores: dict[str, dict[str, Any]] = {}
|
||||
self.metadata_collection_name = "openai_vector_stores_metadata"
|
||||
|
||||
async def initialize(self) -> None:
|
||||
log.info(f"Initializing PGVector memory adapter with config: {self.config}")
|
||||
|
@ -233,9 +400,13 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoco
|
|||
upsert_models(self.conn, [(vector_db.identifier, vector_db)])
|
||||
|
||||
# Create and cache the PGVector index table for the vector DB
|
||||
pgvector_index = PGVectorIndex(
|
||||
vector_db=vector_db, dimension=vector_db.embedding_dimension, conn=self.conn, kvstore=self.kvstore
|
||||
)
|
||||
await pgvector_index.initialize()
|
||||
index = VectorDBWithIndex(
|
||||
vector_db,
|
||||
index=PGVectorIndex(vector_db, vector_db.embedding_dimension, self.conn, kvstore=self.kvstore),
|
||||
index=pgvector_index,
|
||||
inference_api=self.inference_api,
|
||||
)
|
||||
self.cache[vector_db.identifier] = index
|
||||
|
@ -272,8 +443,15 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoco
|
|||
if vector_db_id in self.cache:
|
||||
return self.cache[vector_db_id]
|
||||
|
||||
if self.vector_db_store is None:
|
||||
raise VectorStoreNotFoundError(vector_db_id)
|
||||
|
||||
vector_db = await self.vector_db_store.get_vector_db(vector_db_id)
|
||||
if not vector_db:
|
||||
raise VectorStoreNotFoundError(vector_db_id)
|
||||
|
||||
index = PGVectorIndex(vector_db, vector_db.embedding_dimension, self.conn)
|
||||
await index.initialize()
|
||||
self.cache[vector_db_id] = VectorDBWithIndex(vector_db, index, self.inference_api)
|
||||
return self.cache[vector_db_id]
|
||||
|
||||
|
|
|
@ -294,12 +294,12 @@ class VectorDBWithIndex:
|
|||
_validate_embedding(c.embedding, i, self.vector_db.embedding_dimension)
|
||||
|
||||
if chunks_to_embed:
|
||||
resp = await self.inference_api.embeddings(
|
||||
resp = await self.inference_api.openai_embeddings(
|
||||
self.vector_db.embedding_model,
|
||||
[c.content for c in chunks_to_embed],
|
||||
)
|
||||
for c, embedding in zip(chunks_to_embed, resp.embeddings, strict=False):
|
||||
c.embedding = embedding
|
||||
for c, data in zip(chunks_to_embed, resp.data, strict=False):
|
||||
c.embedding = data.embedding
|
||||
|
||||
embeddings = np.array([c.embedding for c in chunks], dtype=np.float32)
|
||||
await self.index.add_chunks(chunks, embeddings)
|
||||
|
@ -334,8 +334,8 @@ class VectorDBWithIndex:
|
|||
if mode == "keyword":
|
||||
return await self.index.query_keyword(query_string, k, score_threshold)
|
||||
|
||||
embeddings_response = await self.inference_api.embeddings(self.vector_db.embedding_model, [query_string])
|
||||
query_vector = np.array(embeddings_response.embeddings[0], dtype=np.float32)
|
||||
embeddings_response = await self.inference_api.openai_embeddings(self.vector_db.embedding_model, [query_string])
|
||||
query_vector = np.array(embeddings_response.data[0].embedding, dtype=np.float32)
|
||||
if mode == "hybrid":
|
||||
return await self.index.query_hybrid(
|
||||
query_vector, query_string, k, score_threshold, reranker_type, reranker_params
|
||||
|
|
|
@ -23,6 +23,7 @@ from sqlalchemy import (
|
|||
)
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.ext.asyncio.engine import AsyncEngine
|
||||
from sqlalchemy.sql.elements import ColumnElement
|
||||
|
||||
from llama_stack.apis.common.responses import PaginatedResponse
|
||||
from llama_stack.log import get_logger
|
||||
|
@ -43,6 +44,30 @@ TYPE_MAPPING: dict[ColumnType, Any] = {
|
|||
}
|
||||
|
||||
|
||||
def _build_where_expr(column: ColumnElement, value: Any) -> ColumnElement:
|
||||
"""Return a SQLAlchemy expression for a where condition.
|
||||
|
||||
`value` may be a simple scalar (equality) or a mapping like {">": 123}.
|
||||
The returned expression is a SQLAlchemy ColumnElement usable in query.where(...).
|
||||
"""
|
||||
if isinstance(value, Mapping):
|
||||
if len(value) != 1:
|
||||
raise ValueError(f"Operator mapping must have a single operator, got: {value}")
|
||||
op, operand = next(iter(value.items()))
|
||||
if op == "==" or op == "=":
|
||||
return column == operand
|
||||
if op == ">":
|
||||
return column > operand
|
||||
if op == "<":
|
||||
return column < operand
|
||||
if op == ">=":
|
||||
return column >= operand
|
||||
if op == "<=":
|
||||
return column <= operand
|
||||
raise ValueError(f"Unsupported operator '{op}' in where mapping")
|
||||
return column == value
|
||||
|
||||
|
||||
class SqlAlchemySqlStoreImpl(SqlStore):
|
||||
def __init__(self, config: SqlAlchemySqlStoreConfig):
|
||||
self.config = config
|
||||
|
@ -111,7 +136,7 @@ class SqlAlchemySqlStoreImpl(SqlStore):
|
|||
|
||||
if where:
|
||||
for key, value in where.items():
|
||||
query = query.where(table_obj.c[key] == value)
|
||||
query = query.where(_build_where_expr(table_obj.c[key], value))
|
||||
|
||||
if where_sql:
|
||||
query = query.where(text(where_sql))
|
||||
|
@ -222,7 +247,7 @@ class SqlAlchemySqlStoreImpl(SqlStore):
|
|||
async with self.async_session() as session:
|
||||
stmt = self.metadata.tables[table].update()
|
||||
for key, value in where.items():
|
||||
stmt = stmt.where(self.metadata.tables[table].c[key] == value)
|
||||
stmt = stmt.where(_build_where_expr(self.metadata.tables[table].c[key], value))
|
||||
await session.execute(stmt, data)
|
||||
await session.commit()
|
||||
|
||||
|
@ -233,7 +258,7 @@ class SqlAlchemySqlStoreImpl(SqlStore):
|
|||
async with self.async_session() as session:
|
||||
stmt = self.metadata.tables[table].delete()
|
||||
for key, value in where.items():
|
||||
stmt = stmt.where(self.metadata.tables[table].c[key] == value)
|
||||
stmt = stmt.where(_build_where_expr(self.metadata.tables[table].c[key], value))
|
||||
await session.execute(stmt)
|
||||
await session.commit()
|
||||
|
||||
|
|
|
@ -37,3 +37,122 @@ def sanitize_collection_name(name: str, weaviate_format=False) -> str:
|
|||
else:
|
||||
s = proper_case(re.sub(r"[^a-zA-Z0-9]", "", name))
|
||||
return s
|
||||
|
||||
|
||||
class WeightedInMemoryAggregator:
|
||||
@staticmethod
|
||||
def _normalize_scores(scores: dict[str, float]) -> dict[str, float]:
|
||||
"""
|
||||
Normalize scores to 0-1 range using min-max normalization.
|
||||
|
||||
Args:
|
||||
scores: dictionary of scores with document IDs as keys and scores as values
|
||||
|
||||
Returns:
|
||||
Normalized scores with document IDs as keys and normalized scores as values
|
||||
"""
|
||||
if not scores:
|
||||
return {}
|
||||
min_score, max_score = min(scores.values()), max(scores.values())
|
||||
score_range = max_score - min_score
|
||||
if score_range > 0:
|
||||
return {doc_id: (score - min_score) / score_range for doc_id, score in scores.items()}
|
||||
return dict.fromkeys(scores, 1.0)
|
||||
|
||||
@staticmethod
|
||||
def weighted_rerank(
|
||||
vector_scores: dict[str, float],
|
||||
keyword_scores: dict[str, float],
|
||||
alpha: float = 0.5,
|
||||
) -> dict[str, float]:
|
||||
"""
|
||||
Rerank via weighted average of scores.
|
||||
|
||||
Args:
|
||||
vector_scores: scores from vector search
|
||||
keyword_scores: scores from keyword search
|
||||
alpha: weight factor between 0 and 1 (default: 0.5)
|
||||
0 = keyword only, 1 = vector only, 0.5 = equal weight
|
||||
|
||||
Returns:
|
||||
All unique document IDs with weighted combined scores
|
||||
"""
|
||||
all_ids = set(vector_scores.keys()) | set(keyword_scores.keys())
|
||||
normalized_vector_scores = WeightedInMemoryAggregator._normalize_scores(vector_scores)
|
||||
normalized_keyword_scores = WeightedInMemoryAggregator._normalize_scores(keyword_scores)
|
||||
|
||||
# Weighted formula: score = (1-alpha) * keyword_score + alpha * vector_score
|
||||
# alpha=0 means keyword only, alpha=1 means vector only
|
||||
return {
|
||||
doc_id: ((1 - alpha) * normalized_keyword_scores.get(doc_id, 0.0))
|
||||
+ (alpha * normalized_vector_scores.get(doc_id, 0.0))
|
||||
for doc_id in all_ids
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def rrf_rerank(
|
||||
vector_scores: dict[str, float],
|
||||
keyword_scores: dict[str, float],
|
||||
impact_factor: float = 60.0,
|
||||
) -> dict[str, float]:
|
||||
"""
|
||||
Rerank via Reciprocal Rank Fusion.
|
||||
|
||||
Args:
|
||||
vector_scores: scores from vector search
|
||||
keyword_scores: scores from keyword search
|
||||
impact_factor: impact factor for RRF (default: 60.0)
|
||||
|
||||
Returns:
|
||||
All unique document IDs with RRF combined scores
|
||||
"""
|
||||
|
||||
# Convert scores to ranks
|
||||
vector_ranks = {
|
||||
doc_id: i + 1
|
||||
for i, (doc_id, _) in enumerate(sorted(vector_scores.items(), key=lambda x: x[1], reverse=True))
|
||||
}
|
||||
keyword_ranks = {
|
||||
doc_id: i + 1
|
||||
for i, (doc_id, _) in enumerate(sorted(keyword_scores.items(), key=lambda x: x[1], reverse=True))
|
||||
}
|
||||
|
||||
all_ids = set(vector_scores.keys()) | set(keyword_scores.keys())
|
||||
rrf_scores = {}
|
||||
for doc_id in all_ids:
|
||||
vector_rank = vector_ranks.get(doc_id, float("inf"))
|
||||
keyword_rank = keyword_ranks.get(doc_id, float("inf"))
|
||||
|
||||
# RRF formula: score = 1/(k + r) where k is impact_factor (default: 60.0) and r is the rank
|
||||
rrf_scores[doc_id] = (1.0 / (impact_factor + vector_rank)) + (1.0 / (impact_factor + keyword_rank))
|
||||
return rrf_scores
|
||||
|
||||
@staticmethod
|
||||
def combine_search_results(
|
||||
vector_scores: dict[str, float],
|
||||
keyword_scores: dict[str, float],
|
||||
reranker_type: str = "rrf",
|
||||
reranker_params: dict[str, float] | None = None,
|
||||
) -> dict[str, float]:
|
||||
"""
|
||||
Combine vector and keyword search results using specified reranking strategy.
|
||||
|
||||
Args:
|
||||
vector_scores: scores from vector search
|
||||
keyword_scores: scores from keyword search
|
||||
reranker_type: type of reranker to use (default: RERANKER_TYPE_RRF)
|
||||
reranker_params: parameters for the reranker
|
||||
|
||||
Returns:
|
||||
All unique document IDs with combined scores
|
||||
"""
|
||||
if reranker_params is None:
|
||||
reranker_params = {}
|
||||
|
||||
if reranker_type == "weighted":
|
||||
alpha = reranker_params.get("alpha", 0.5)
|
||||
return WeightedInMemoryAggregator.weighted_rerank(vector_scores, keyword_scores, alpha)
|
||||
else:
|
||||
# Default to RRF for None, RRF, or any unknown types
|
||||
impact_factor = reranker_params.get("impact_factor", 60.0)
|
||||
return WeightedInMemoryAggregator.rrf_rerank(vector_scores, keyword_scores, impact_factor)
|
||||
|
|
|
@ -9,7 +9,6 @@ from __future__ import annotations # for forward references
|
|||
import hashlib
|
||||
import json
|
||||
import os
|
||||
import sqlite3
|
||||
from collections.abc import Generator
|
||||
from contextlib import contextmanager
|
||||
from enum import StrEnum
|
||||
|
@ -31,6 +30,9 @@ from openai.types.completion_choice import CompletionChoice
|
|||
CompletionChoice.model_fields["finish_reason"].annotation = Literal["stop", "length", "content_filter"] | None
|
||||
CompletionChoice.model_rebuild()
|
||||
|
||||
REPO_ROOT = Path(__file__).parent.parent.parent
|
||||
DEFAULT_STORAGE_DIR = REPO_ROOT / "tests/integration/recordings"
|
||||
|
||||
|
||||
class InferenceMode(StrEnum):
|
||||
LIVE = "live"
|
||||
|
@ -52,7 +54,7 @@ def normalize_request(method: str, url: str, headers: dict[str, Any], body: dict
|
|||
|
||||
|
||||
def get_inference_mode() -> InferenceMode:
|
||||
return InferenceMode(os.environ.get("LLAMA_STACK_TEST_INFERENCE_MODE", "live").lower())
|
||||
return InferenceMode(os.environ.get("LLAMA_STACK_TEST_INFERENCE_MODE", "replay").lower())
|
||||
|
||||
|
||||
def setup_inference_recording():
|
||||
|
@ -61,28 +63,18 @@ def setup_inference_recording():
|
|||
to increase their reliability and reduce reliance on expensive, external services.
|
||||
|
||||
Currently, this is only supported for OpenAI and Ollama clients. These should cover the vast majority of use cases.
|
||||
Calls to the /models endpoint are not currently trapped. We probably need to add support for this.
|
||||
|
||||
Two environment variables are required:
|
||||
- LLAMA_STACK_TEST_INFERENCE_MODE: The mode to run in. Must be 'live', 'record', or 'replay'.
|
||||
- LLAMA_STACK_TEST_RECORDING_DIR: The directory to store the recordings in.
|
||||
Two environment variables are supported:
|
||||
- LLAMA_STACK_TEST_INFERENCE_MODE: The mode to run in. Must be 'live', 'record', or 'replay'. Default is 'replay'.
|
||||
- LLAMA_STACK_TEST_RECORDING_DIR: The directory to store the recordings in. Default is 'tests/integration/recordings'.
|
||||
|
||||
The recordings are stored in a SQLite database and a JSON file for each request. The SQLite database is used to
|
||||
quickly find the correct recording for a given request. The JSON files are used to store the request and response
|
||||
bodies.
|
||||
The recordings are stored as JSON files.
|
||||
"""
|
||||
mode = get_inference_mode()
|
||||
|
||||
if mode not in InferenceMode:
|
||||
raise ValueError(f"Invalid LLAMA_STACK_TEST_INFERENCE_MODE: {mode}. Must be 'live', 'record', or 'replay'")
|
||||
|
||||
if mode == InferenceMode.LIVE:
|
||||
return None
|
||||
|
||||
if "LLAMA_STACK_TEST_RECORDING_DIR" not in os.environ:
|
||||
raise ValueError("LLAMA_STACK_TEST_RECORDING_DIR must be set for recording or replaying")
|
||||
storage_dir = os.environ["LLAMA_STACK_TEST_RECORDING_DIR"]
|
||||
|
||||
storage_dir = os.environ.get("LLAMA_STACK_TEST_RECORDING_DIR", DEFAULT_STORAGE_DIR)
|
||||
return inference_recording(mode=mode, storage_dir=storage_dir)
|
||||
|
||||
|
||||
|
@ -125,33 +117,18 @@ class ResponseStorage:
|
|||
def __init__(self, test_dir: Path):
|
||||
self.test_dir = test_dir
|
||||
self.responses_dir = self.test_dir / "responses"
|
||||
self.db_path = self.test_dir / "index.sqlite"
|
||||
|
||||
self._ensure_directories()
|
||||
self._init_database()
|
||||
|
||||
def _ensure_directories(self):
|
||||
self.test_dir.mkdir(parents=True, exist_ok=True)
|
||||
self.responses_dir.mkdir(exist_ok=True)
|
||||
|
||||
def _init_database(self):
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS recordings (
|
||||
request_hash TEXT PRIMARY KEY,
|
||||
response_file TEXT,
|
||||
endpoint TEXT,
|
||||
model TEXT,
|
||||
timestamp TEXT,
|
||||
is_streaming BOOLEAN
|
||||
)
|
||||
""")
|
||||
|
||||
def store_recording(self, request_hash: str, request: dict[str, Any], response: dict[str, Any]):
|
||||
"""Store a request/response pair."""
|
||||
# Generate unique response filename
|
||||
response_file = f"{request_hash[:12]}.json"
|
||||
response_path = self.responses_dir / response_file
|
||||
short_hash = request_hash[:12]
|
||||
response_file = f"{short_hash}.json"
|
||||
|
||||
# Serialize response body if needed
|
||||
serialized_response = dict(response)
|
||||
|
@ -163,58 +140,107 @@ class ResponseStorage:
|
|||
# Handle single response
|
||||
serialized_response["body"] = _serialize_response(serialized_response["body"])
|
||||
|
||||
# If this is an Ollama /api/tags recording, include models digest in filename to distinguish variants
|
||||
endpoint = request.get("endpoint")
|
||||
if endpoint in ("/api/tags", "/v1/models"):
|
||||
digest = _model_identifiers_digest(endpoint, response)
|
||||
response_file = f"models-{short_hash}-{digest}.json"
|
||||
|
||||
response_path = self.responses_dir / response_file
|
||||
|
||||
# Save response to JSON file
|
||||
with open(response_path, "w") as f:
|
||||
json.dump({"request": request, "response": serialized_response}, f, indent=2)
|
||||
f.write("\n")
|
||||
f.flush()
|
||||
|
||||
# Update SQLite index
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT OR REPLACE INTO recordings
|
||||
(request_hash, response_file, endpoint, model, timestamp, is_streaming)
|
||||
VALUES (?, ?, ?, ?, datetime('now'), ?)
|
||||
""",
|
||||
(
|
||||
request_hash,
|
||||
response_file,
|
||||
request.get("endpoint", ""),
|
||||
request.get("model", ""),
|
||||
response.get("is_streaming", False),
|
||||
),
|
||||
)
|
||||
|
||||
def find_recording(self, request_hash: str) -> dict[str, Any] | None:
|
||||
"""Find a recorded response by request hash."""
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
result = conn.execute(
|
||||
"SELECT response_file FROM recordings WHERE request_hash = ?", (request_hash,)
|
||||
).fetchone()
|
||||
|
||||
if not result:
|
||||
return None
|
||||
|
||||
response_file = result[0]
|
||||
response_file = f"{request_hash[:12]}.json"
|
||||
response_path = self.responses_dir / response_file
|
||||
|
||||
if not response_path.exists():
|
||||
return None
|
||||
|
||||
with open(response_path) as f:
|
||||
data = json.load(f)
|
||||
return _recording_from_file(response_path)
|
||||
|
||||
# Deserialize response body if needed
|
||||
if "response" in data and "body" in data["response"]:
|
||||
if isinstance(data["response"]["body"], list):
|
||||
# Handle streaming responses
|
||||
data["response"]["body"] = [_deserialize_response(chunk) for chunk in data["response"]["body"]]
|
||||
def _model_list_responses(self, short_hash: str) -> list[dict[str, Any]]:
|
||||
results: list[dict[str, Any]] = []
|
||||
for path in self.responses_dir.glob(f"models-{short_hash}-*.json"):
|
||||
data = _recording_from_file(path)
|
||||
results.append(data)
|
||||
return results
|
||||
|
||||
|
||||
def _recording_from_file(response_path) -> dict[str, Any]:
|
||||
with open(response_path) as f:
|
||||
data = json.load(f)
|
||||
|
||||
# Deserialize response body if needed
|
||||
if "response" in data and "body" in data["response"]:
|
||||
if isinstance(data["response"]["body"], list):
|
||||
# Handle streaming responses
|
||||
data["response"]["body"] = [_deserialize_response(chunk) for chunk in data["response"]["body"]]
|
||||
else:
|
||||
# Handle single response
|
||||
data["response"]["body"] = _deserialize_response(data["response"]["body"])
|
||||
|
||||
return cast(dict[str, Any], data)
|
||||
|
||||
|
||||
def _model_identifiers_digest(endpoint: str, response: dict[str, Any]) -> str:
|
||||
def _extract_model_identifiers():
|
||||
"""Extract a stable set of identifiers for model-list endpoints.
|
||||
|
||||
Supported endpoints:
|
||||
- '/api/tags' (Ollama): response body has 'models': [ { name/model/digest/id/... }, ... ]
|
||||
- '/v1/models' (OpenAI): response body has 'data': [ { id: ... }, ... ]
|
||||
Returns a list of unique identifiers or None if structure doesn't match.
|
||||
"""
|
||||
body = response["body"]
|
||||
if endpoint == "/api/tags":
|
||||
items = body.get("models")
|
||||
idents = [m.model for m in items]
|
||||
else:
|
||||
items = body.get("data")
|
||||
idents = [m.id for m in items]
|
||||
return sorted(set(idents))
|
||||
|
||||
identifiers = _extract_model_identifiers()
|
||||
return hashlib.sha1(("|".join(identifiers)).encode("utf-8")).hexdigest()[:8]
|
||||
|
||||
|
||||
def _combine_model_list_responses(endpoint: str, records: list[dict[str, Any]]) -> dict[str, Any] | None:
|
||||
"""Return a single, unioned recording for supported model-list endpoints."""
|
||||
seen: dict[str, dict[str, Any]] = {}
|
||||
for rec in records:
|
||||
body = rec["response"]["body"]
|
||||
if endpoint == "/api/tags":
|
||||
items = body.models
|
||||
elif endpoint == "/v1/models":
|
||||
items = body.data
|
||||
else:
|
||||
items = []
|
||||
|
||||
for m in items:
|
||||
if endpoint == "/v1/models":
|
||||
key = m.id
|
||||
else:
|
||||
# Handle single response
|
||||
data["response"]["body"] = _deserialize_response(data["response"]["body"])
|
||||
key = m.model
|
||||
seen[key] = m
|
||||
|
||||
return cast(dict[str, Any], data)
|
||||
ordered = [seen[k] for k in sorted(seen.keys())]
|
||||
canonical = records[0]
|
||||
canonical_req = canonical.get("request", {})
|
||||
if isinstance(canonical_req, dict):
|
||||
canonical_req["endpoint"] = endpoint
|
||||
if endpoint == "/v1/models":
|
||||
body = {"data": ordered, "object": "list"}
|
||||
else:
|
||||
from ollama import ListResponse
|
||||
|
||||
body = ListResponse(models=ordered)
|
||||
return {"request": canonical_req, "response": {"body": body, "is_streaming": False}}
|
||||
|
||||
|
||||
async def _patched_inference_method(original_method, self, client_type, endpoint, *args, **kwargs):
|
||||
|
@ -236,8 +262,6 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
|
|||
raise ValueError(f"Unknown client type: {client_type}")
|
||||
|
||||
url = base_url.rstrip("/") + endpoint
|
||||
|
||||
# Normalize request for matching
|
||||
method = "POST"
|
||||
headers = {}
|
||||
body = kwargs
|
||||
|
@ -245,7 +269,12 @@ async def _patched_inference_method(original_method, self, client_type, endpoint
|
|||
request_hash = normalize_request(method, url, headers, body)
|
||||
|
||||
if _current_mode == InferenceMode.REPLAY:
|
||||
recording = _current_storage.find_recording(request_hash)
|
||||
# Special handling for model-list endpoints: return union of all responses
|
||||
if endpoint in ("/api/tags", "/v1/models"):
|
||||
records = _current_storage._model_list_responses(request_hash[:12])
|
||||
recording = _combine_model_list_responses(endpoint, records)
|
||||
else:
|
||||
recording = _current_storage.find_recording(request_hash)
|
||||
if recording:
|
||||
response_body = recording["response"]["body"]
|
||||
|
||||
|
@ -315,12 +344,14 @@ def patch_inference_clients():
|
|||
from openai.resources.chat.completions import AsyncCompletions as AsyncChatCompletions
|
||||
from openai.resources.completions import AsyncCompletions
|
||||
from openai.resources.embeddings import AsyncEmbeddings
|
||||
from openai.resources.models import AsyncModels
|
||||
|
||||
# Store original methods for both OpenAI and Ollama clients
|
||||
_original_methods = {
|
||||
"chat_completions_create": AsyncChatCompletions.create,
|
||||
"completions_create": AsyncCompletions.create,
|
||||
"embeddings_create": AsyncEmbeddings.create,
|
||||
"models_list": AsyncModels.list,
|
||||
"ollama_generate": OllamaAsyncClient.generate,
|
||||
"ollama_chat": OllamaAsyncClient.chat,
|
||||
"ollama_embed": OllamaAsyncClient.embed,
|
||||
|
@ -345,10 +376,16 @@ def patch_inference_clients():
|
|||
_original_methods["embeddings_create"], self, "openai", "/v1/embeddings", *args, **kwargs
|
||||
)
|
||||
|
||||
async def patched_models_list(self, *args, **kwargs):
|
||||
return await _patched_inference_method(
|
||||
_original_methods["models_list"], self, "openai", "/v1/models", *args, **kwargs
|
||||
)
|
||||
|
||||
# Apply OpenAI patches
|
||||
AsyncChatCompletions.create = patched_chat_completions_create
|
||||
AsyncCompletions.create = patched_completions_create
|
||||
AsyncEmbeddings.create = patched_embeddings_create
|
||||
AsyncModels.list = patched_models_list
|
||||
|
||||
# Create patched methods for Ollama client
|
||||
async def patched_ollama_generate(self, *args, **kwargs):
|
||||
|
@ -402,11 +439,13 @@ def unpatch_inference_clients():
|
|||
from openai.resources.chat.completions import AsyncCompletions as AsyncChatCompletions
|
||||
from openai.resources.completions import AsyncCompletions
|
||||
from openai.resources.embeddings import AsyncEmbeddings
|
||||
from openai.resources.models import AsyncModels
|
||||
|
||||
# Restore OpenAI client methods
|
||||
AsyncChatCompletions.create = _original_methods["chat_completions_create"]
|
||||
AsyncCompletions.create = _original_methods["completions_create"]
|
||||
AsyncEmbeddings.create = _original_methods["embeddings_create"]
|
||||
AsyncModels.list = _original_methods["models_list"]
|
||||
|
||||
# Restore Ollama client methods if they were patched
|
||||
OllamaAsyncClient.generate = _original_methods["ollama_generate"]
|
||||
|
@ -420,16 +459,10 @@ def unpatch_inference_clients():
|
|||
|
||||
|
||||
@contextmanager
|
||||
def inference_recording(mode: str = "live", storage_dir: str | Path | None = None) -> Generator[None, None, None]:
|
||||
def inference_recording(mode: str, storage_dir: str | Path | None = None) -> Generator[None, None, None]:
|
||||
"""Context manager for inference recording/replaying."""
|
||||
global _current_mode, _current_storage
|
||||
|
||||
# Set defaults
|
||||
if storage_dir is None:
|
||||
storage_dir_path = Path.home() / ".llama" / "recordings"
|
||||
else:
|
||||
storage_dir_path = Path(storage_dir)
|
||||
|
||||
# Store previous state
|
||||
prev_mode = _current_mode
|
||||
prev_storage = _current_storage
|
||||
|
@ -438,7 +471,9 @@ def inference_recording(mode: str = "live", storage_dir: str | Path | None = Non
|
|||
_current_mode = mode
|
||||
|
||||
if mode in ["record", "replay"]:
|
||||
_current_storage = ResponseStorage(storage_dir_path)
|
||||
if storage_dir is None:
|
||||
raise ValueError("storage_dir is required for record and replay modes")
|
||||
_current_storage = ResponseStorage(Path(storage_dir))
|
||||
patch_inference_clients()
|
||||
|
||||
yield
|
||||
|
|
610
llama_stack/ui/app/chat-playground/chunk-processor.test.tsx
Normal file
610
llama_stack/ui/app/chat-playground/chunk-processor.test.tsx
Normal file
|
@ -0,0 +1,610 @@
|
|||
import { describe, test, expect } from "@jest/globals";
|
||||
|
||||
// Extract the exact processChunk function implementation for testing
|
||||
function createProcessChunk() {
|
||||
return (chunk: unknown): { text: string | null; isToolCall: boolean } => {
|
||||
const chunkObj = chunk as Record<string, unknown>;
|
||||
|
||||
// Helper function to check if content contains function call JSON
|
||||
const containsToolCall = (content: string): boolean => {
|
||||
return (
|
||||
content.includes('"type": "function"') ||
|
||||
content.includes('"name": "knowledge_search"') ||
|
||||
content.includes('"parameters":') ||
|
||||
!!content.match(/\{"type":\s*"function".*?\}/)
|
||||
);
|
||||
};
|
||||
|
||||
// Check if this chunk contains a tool call (function call)
|
||||
let isToolCall = false;
|
||||
|
||||
// Check direct chunk content if it's a string
|
||||
if (typeof chunk === "string") {
|
||||
isToolCall = containsToolCall(chunk);
|
||||
}
|
||||
|
||||
// Check delta structures
|
||||
if (
|
||||
chunkObj?.delta &&
|
||||
typeof chunkObj.delta === "object" &&
|
||||
chunkObj.delta !== null
|
||||
) {
|
||||
const delta = chunkObj.delta as Record<string, unknown>;
|
||||
if ("tool_calls" in delta) {
|
||||
isToolCall = true;
|
||||
}
|
||||
if (typeof delta.text === "string") {
|
||||
if (containsToolCall(delta.text)) {
|
||||
isToolCall = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check event structures
|
||||
if (
|
||||
chunkObj?.event &&
|
||||
typeof chunkObj.event === "object" &&
|
||||
chunkObj.event !== null
|
||||
) {
|
||||
const event = chunkObj.event as Record<string, unknown>;
|
||||
|
||||
// Check event payload
|
||||
if (
|
||||
event?.payload &&
|
||||
typeof event.payload === "object" &&
|
||||
event.payload !== null
|
||||
) {
|
||||
const payload = event.payload as Record<string, unknown>;
|
||||
if (typeof payload.content === "string") {
|
||||
if (containsToolCall(payload.content)) {
|
||||
isToolCall = true;
|
||||
}
|
||||
}
|
||||
|
||||
// Check payload delta
|
||||
if (
|
||||
payload?.delta &&
|
||||
typeof payload.delta === "object" &&
|
||||
payload.delta !== null
|
||||
) {
|
||||
const delta = payload.delta as Record<string, unknown>;
|
||||
if (typeof delta.text === "string") {
|
||||
if (containsToolCall(delta.text)) {
|
||||
isToolCall = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check event delta
|
||||
if (
|
||||
event?.delta &&
|
||||
typeof event.delta === "object" &&
|
||||
event.delta !== null
|
||||
) {
|
||||
const delta = event.delta as Record<string, unknown>;
|
||||
if (typeof delta.text === "string") {
|
||||
if (containsToolCall(delta.text)) {
|
||||
isToolCall = true;
|
||||
}
|
||||
}
|
||||
if (typeof delta.content === "string") {
|
||||
if (containsToolCall(delta.content)) {
|
||||
isToolCall = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// if it's a tool call, skip it (don't display in chat)
|
||||
if (isToolCall) {
|
||||
return { text: null, isToolCall: true };
|
||||
}
|
||||
|
||||
// Extract text content from various chunk formats
|
||||
let text: string | null = null;
|
||||
|
||||
// Helper function to extract clean text content, filtering out function calls
|
||||
const extractCleanText = (content: string): string | null => {
|
||||
if (containsToolCall(content)) {
|
||||
try {
|
||||
// Try to parse and extract non-function call parts
|
||||
const jsonMatch = content.match(
|
||||
/\{"type":\s*"function"[^}]*\}[^}]*\}/
|
||||
);
|
||||
if (jsonMatch) {
|
||||
const jsonPart = jsonMatch[0];
|
||||
const parsedJson = JSON.parse(jsonPart);
|
||||
|
||||
// If it's a function call, extract text after JSON
|
||||
if (parsedJson.type === "function") {
|
||||
const textAfterJson = content
|
||||
.substring(content.indexOf(jsonPart) + jsonPart.length)
|
||||
.trim();
|
||||
return textAfterJson || null;
|
||||
}
|
||||
}
|
||||
// If we can't parse it properly, skip the whole thing
|
||||
return null;
|
||||
} catch {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
return content;
|
||||
};
|
||||
|
||||
// Try direct delta text
|
||||
if (
|
||||
chunkObj?.delta &&
|
||||
typeof chunkObj.delta === "object" &&
|
||||
chunkObj.delta !== null
|
||||
) {
|
||||
const delta = chunkObj.delta as Record<string, unknown>;
|
||||
if (typeof delta.text === "string") {
|
||||
text = extractCleanText(delta.text);
|
||||
}
|
||||
}
|
||||
|
||||
// Try event structures
|
||||
if (
|
||||
!text &&
|
||||
chunkObj?.event &&
|
||||
typeof chunkObj.event === "object" &&
|
||||
chunkObj.event !== null
|
||||
) {
|
||||
const event = chunkObj.event as Record<string, unknown>;
|
||||
|
||||
// Try event payload content
|
||||
if (
|
||||
event?.payload &&
|
||||
typeof event.payload === "object" &&
|
||||
event.payload !== null
|
||||
) {
|
||||
const payload = event.payload as Record<string, unknown>;
|
||||
|
||||
// Try direct payload content
|
||||
if (typeof payload.content === "string") {
|
||||
text = extractCleanText(payload.content);
|
||||
}
|
||||
|
||||
// Try turn_complete event structure: payload.turn.output_message.content
|
||||
if (
|
||||
!text &&
|
||||
payload?.turn &&
|
||||
typeof payload.turn === "object" &&
|
||||
payload.turn !== null
|
||||
) {
|
||||
const turn = payload.turn as Record<string, unknown>;
|
||||
if (
|
||||
turn?.output_message &&
|
||||
typeof turn.output_message === "object" &&
|
||||
turn.output_message !== null
|
||||
) {
|
||||
const outputMessage = turn.output_message as Record<
|
||||
string,
|
||||
unknown
|
||||
>;
|
||||
if (typeof outputMessage.content === "string") {
|
||||
text = extractCleanText(outputMessage.content);
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback to model_response in steps if no output_message
|
||||
if (
|
||||
!text &&
|
||||
turn?.steps &&
|
||||
Array.isArray(turn.steps) &&
|
||||
turn.steps.length > 0
|
||||
) {
|
||||
for (const step of turn.steps) {
|
||||
if (step && typeof step === "object" && step !== null) {
|
||||
const stepObj = step as Record<string, unknown>;
|
||||
if (
|
||||
stepObj?.model_response &&
|
||||
typeof stepObj.model_response === "object" &&
|
||||
stepObj.model_response !== null
|
||||
) {
|
||||
const modelResponse = stepObj.model_response as Record<
|
||||
string,
|
||||
unknown
|
||||
>;
|
||||
if (typeof modelResponse.content === "string") {
|
||||
text = extractCleanText(modelResponse.content);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Try payload delta
|
||||
if (
|
||||
!text &&
|
||||
payload?.delta &&
|
||||
typeof payload.delta === "object" &&
|
||||
payload.delta !== null
|
||||
) {
|
||||
const delta = payload.delta as Record<string, unknown>;
|
||||
if (typeof delta.text === "string") {
|
||||
text = extractCleanText(delta.text);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Try event delta
|
||||
if (
|
||||
!text &&
|
||||
event?.delta &&
|
||||
typeof event.delta === "object" &&
|
||||
event.delta !== null
|
||||
) {
|
||||
const delta = event.delta as Record<string, unknown>;
|
||||
if (typeof delta.text === "string") {
|
||||
text = extractCleanText(delta.text);
|
||||
}
|
||||
if (!text && typeof delta.content === "string") {
|
||||
text = extractCleanText(delta.content);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Try choices structure (ChatML format)
|
||||
if (
|
||||
!text &&
|
||||
chunkObj?.choices &&
|
||||
Array.isArray(chunkObj.choices) &&
|
||||
chunkObj.choices.length > 0
|
||||
) {
|
||||
const choice = chunkObj.choices[0] as Record<string, unknown>;
|
||||
if (
|
||||
choice?.delta &&
|
||||
typeof choice.delta === "object" &&
|
||||
choice.delta !== null
|
||||
) {
|
||||
const delta = choice.delta as Record<string, unknown>;
|
||||
if (typeof delta.content === "string") {
|
||||
text = extractCleanText(delta.content);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Try direct string content
|
||||
if (!text && typeof chunk === "string") {
|
||||
text = extractCleanText(chunk);
|
||||
}
|
||||
|
||||
return { text, isToolCall: false };
|
||||
};
|
||||
}
|
||||
|
||||
describe("Chunk Processor", () => {
|
||||
const processChunk = createProcessChunk();
|
||||
|
||||
describe("Real Event Structures", () => {
|
||||
test("handles turn_complete event with cancellation policy response", () => {
|
||||
const chunk = {
|
||||
event: {
|
||||
payload: {
|
||||
event_type: "turn_complete",
|
||||
turn: {
|
||||
turn_id: "50a2d6b7-49ed-4d1e-b1c2-6d68b3f726db",
|
||||
session_id: "e7f62b8e-518c-4450-82df-e65fe49f27a3",
|
||||
input_messages: [
|
||||
{
|
||||
role: "user",
|
||||
content: "nice, what's the cancellation policy?",
|
||||
context: null,
|
||||
},
|
||||
],
|
||||
steps: [
|
||||
{
|
||||
turn_id: "50a2d6b7-49ed-4d1e-b1c2-6d68b3f726db",
|
||||
step_id: "54074310-af42-414c-9ffe-fba5b2ead0ad",
|
||||
started_at: "2025-08-27T18:15:25.870703Z",
|
||||
completed_at: "2025-08-27T18:15:51.288993Z",
|
||||
step_type: "inference",
|
||||
model_response: {
|
||||
role: "assistant",
|
||||
content:
|
||||
"According to the search results, the cancellation policy for Red Hat Summit is as follows:\n\n* Cancellations must be received by 5 PM EDT on April 18, 2025 for a 50% refund of the registration fee.\n* No refunds will be given for cancellations received after 5 PM EDT on April 18, 2025.\n* Cancellation of travel reservations and hotel reservations are the responsibility of the registrant.",
|
||||
stop_reason: "end_of_turn",
|
||||
tool_calls: [],
|
||||
},
|
||||
},
|
||||
],
|
||||
output_message: {
|
||||
role: "assistant",
|
||||
content:
|
||||
"According to the search results, the cancellation policy for Red Hat Summit is as follows:\n\n* Cancellations must be received by 5 PM EDT on April 18, 2025 for a 50% refund of the registration fee.\n* No refunds will be given for cancellations received after 5 PM EDT on April 18, 2025.\n* Cancellation of travel reservations and hotel reservations are the responsibility of the registrant.",
|
||||
stop_reason: "end_of_turn",
|
||||
tool_calls: [],
|
||||
},
|
||||
output_attachments: [],
|
||||
started_at: "2025-08-27T18:15:25.868548Z",
|
||||
completed_at: "2025-08-27T18:15:51.289262Z",
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
const result = processChunk(chunk);
|
||||
expect(result.isToolCall).toBe(false);
|
||||
expect(result.text).toContain(
|
||||
"According to the search results, the cancellation policy for Red Hat Summit is as follows:"
|
||||
);
|
||||
expect(result.text).toContain("5 PM EDT on April 18, 2025");
|
||||
});
|
||||
|
||||
test("handles turn_complete event with address response", () => {
|
||||
const chunk = {
|
||||
event: {
|
||||
payload: {
|
||||
event_type: "turn_complete",
|
||||
turn: {
|
||||
turn_id: "2f4a1520-8ecc-4cb7-bb7b-886939e042b0",
|
||||
session_id: "e7f62b8e-518c-4450-82df-e65fe49f27a3",
|
||||
input_messages: [
|
||||
{
|
||||
role: "user",
|
||||
content: "what's francisco's address",
|
||||
context: null,
|
||||
},
|
||||
],
|
||||
steps: [
|
||||
{
|
||||
turn_id: "2f4a1520-8ecc-4cb7-bb7b-886939e042b0",
|
||||
step_id: "c13dd277-1acb-4419-8fbf-d5e2f45392ea",
|
||||
started_at: "2025-08-27T18:14:52.558761Z",
|
||||
completed_at: "2025-08-27T18:15:11.306032Z",
|
||||
step_type: "inference",
|
||||
model_response: {
|
||||
role: "assistant",
|
||||
content:
|
||||
"Francisco Arceo's address is:\n\nRed Hat\nUnited States\n17 Primrose Ln \nBasking Ridge New Jersey 07920",
|
||||
stop_reason: "end_of_turn",
|
||||
tool_calls: [],
|
||||
},
|
||||
},
|
||||
],
|
||||
output_message: {
|
||||
role: "assistant",
|
||||
content:
|
||||
"Francisco Arceo's address is:\n\nRed Hat\nUnited States\n17 Primrose Ln \nBasking Ridge New Jersey 07920",
|
||||
stop_reason: "end_of_turn",
|
||||
tool_calls: [],
|
||||
},
|
||||
output_attachments: [],
|
||||
started_at: "2025-08-27T18:14:52.553707Z",
|
||||
completed_at: "2025-08-27T18:15:11.306729Z",
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
const result = processChunk(chunk);
|
||||
expect(result.isToolCall).toBe(false);
|
||||
expect(result.text).toContain("Francisco Arceo's address is:");
|
||||
expect(result.text).toContain("17 Primrose Ln");
|
||||
expect(result.text).toContain("Basking Ridge New Jersey 07920");
|
||||
});
|
||||
|
||||
test("handles turn_complete event with ticket cost response", () => {
|
||||
const chunk = {
|
||||
event: {
|
||||
payload: {
|
||||
event_type: "turn_complete",
|
||||
turn: {
|
||||
turn_id: "7ef244a3-efee-42ca-a9c8-942865251002",
|
||||
session_id: "e7f62b8e-518c-4450-82df-e65fe49f27a3",
|
||||
input_messages: [
|
||||
{
|
||||
role: "user",
|
||||
content: "what was the ticket cost for summit?",
|
||||
context: null,
|
||||
},
|
||||
],
|
||||
steps: [
|
||||
{
|
||||
turn_id: "7ef244a3-efee-42ca-a9c8-942865251002",
|
||||
step_id: "7651dda0-315a-472d-b1c1-3c2725f55bc5",
|
||||
started_at: "2025-08-27T18:14:21.710611Z",
|
||||
completed_at: "2025-08-27T18:14:39.706452Z",
|
||||
step_type: "inference",
|
||||
model_response: {
|
||||
role: "assistant",
|
||||
content:
|
||||
"The ticket cost for the Red Hat Summit was $999.00 for a conference pass.",
|
||||
stop_reason: "end_of_turn",
|
||||
tool_calls: [],
|
||||
},
|
||||
},
|
||||
],
|
||||
output_message: {
|
||||
role: "assistant",
|
||||
content:
|
||||
"The ticket cost for the Red Hat Summit was $999.00 for a conference pass.",
|
||||
stop_reason: "end_of_turn",
|
||||
tool_calls: [],
|
||||
},
|
||||
output_attachments: [],
|
||||
started_at: "2025-08-27T18:14:21.705289Z",
|
||||
completed_at: "2025-08-27T18:14:39.706752Z",
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
const result = processChunk(chunk);
|
||||
expect(result.isToolCall).toBe(false);
|
||||
expect(result.text).toBe(
|
||||
"The ticket cost for the Red Hat Summit was $999.00 for a conference pass."
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
describe("Function Call Detection", () => {
|
||||
test("detects function calls in direct string chunks", () => {
|
||||
const chunk =
|
||||
'{"type": "function", "name": "knowledge_search", "parameters": {"query": "test"}}';
|
||||
const result = processChunk(chunk);
|
||||
expect(result.isToolCall).toBe(true);
|
||||
expect(result.text).toBe(null);
|
||||
});
|
||||
|
||||
test("detects function calls in event payload content", () => {
|
||||
const chunk = {
|
||||
event: {
|
||||
payload: {
|
||||
content:
|
||||
'{"type": "function", "name": "knowledge_search", "parameters": {"query": "test"}}',
|
||||
},
|
||||
},
|
||||
};
|
||||
const result = processChunk(chunk);
|
||||
expect(result.isToolCall).toBe(true);
|
||||
expect(result.text).toBe(null);
|
||||
});
|
||||
|
||||
test("detects tool_calls in delta structure", () => {
|
||||
const chunk = {
|
||||
delta: {
|
||||
tool_calls: [{ function: { name: "knowledge_search" } }],
|
||||
},
|
||||
};
|
||||
const result = processChunk(chunk);
|
||||
expect(result.isToolCall).toBe(true);
|
||||
expect(result.text).toBe(null);
|
||||
});
|
||||
|
||||
test("detects function call in mixed content but skips it", () => {
|
||||
const chunk =
|
||||
'{"type": "function", "name": "knowledge_search", "parameters": {"query": "test"}} Based on the search results, here is your answer.';
|
||||
const result = processChunk(chunk);
|
||||
// This is detected as a tool call and skipped entirely - the implementation prioritizes safety
|
||||
expect(result.isToolCall).toBe(true);
|
||||
expect(result.text).toBe(null);
|
||||
});
|
||||
});
|
||||
|
||||
describe("Text Extraction", () => {
|
||||
test("extracts text from direct string chunks", () => {
|
||||
const chunk = "Hello, this is a normal response.";
|
||||
const result = processChunk(chunk);
|
||||
expect(result.isToolCall).toBe(false);
|
||||
expect(result.text).toBe("Hello, this is a normal response.");
|
||||
});
|
||||
|
||||
test("extracts text from delta structure", () => {
|
||||
const chunk = {
|
||||
delta: {
|
||||
text: "Hello, this is a normal response.",
|
||||
},
|
||||
};
|
||||
const result = processChunk(chunk);
|
||||
expect(result.isToolCall).toBe(false);
|
||||
expect(result.text).toBe("Hello, this is a normal response.");
|
||||
});
|
||||
|
||||
test("extracts text from choices structure", () => {
|
||||
const chunk = {
|
||||
choices: [
|
||||
{
|
||||
delta: {
|
||||
content: "Hello, this is a normal response.",
|
||||
},
|
||||
},
|
||||
],
|
||||
};
|
||||
const result = processChunk(chunk);
|
||||
expect(result.isToolCall).toBe(false);
|
||||
expect(result.text).toBe("Hello, this is a normal response.");
|
||||
});
|
||||
|
||||
test("prioritizes output_message over model_response in turn structure", () => {
|
||||
const chunk = {
|
||||
event: {
|
||||
payload: {
|
||||
turn: {
|
||||
steps: [
|
||||
{
|
||||
model_response: {
|
||||
content: "Model response content.",
|
||||
},
|
||||
},
|
||||
],
|
||||
output_message: {
|
||||
content: "Final output message content.",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
const result = processChunk(chunk);
|
||||
expect(result.isToolCall).toBe(false);
|
||||
expect(result.text).toBe("Final output message content.");
|
||||
});
|
||||
|
||||
test("falls back to model_response when no output_message", () => {
|
||||
const chunk = {
|
||||
event: {
|
||||
payload: {
|
||||
turn: {
|
||||
steps: [
|
||||
{
|
||||
model_response: {
|
||||
content: "This is from the model response.",
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
},
|
||||
},
|
||||
};
|
||||
const result = processChunk(chunk);
|
||||
expect(result.isToolCall).toBe(false);
|
||||
expect(result.text).toBe("This is from the model response.");
|
||||
});
|
||||
});
|
||||
|
||||
describe("Edge Cases", () => {
|
||||
test("handles empty chunks", () => {
|
||||
const result = processChunk("");
|
||||
expect(result.isToolCall).toBe(false);
|
||||
expect(result.text).toBe("");
|
||||
});
|
||||
|
||||
test("handles null chunks", () => {
|
||||
const result = processChunk(null);
|
||||
expect(result.isToolCall).toBe(false);
|
||||
expect(result.text).toBe(null);
|
||||
});
|
||||
|
||||
test("handles undefined chunks", () => {
|
||||
const result = processChunk(undefined);
|
||||
expect(result.isToolCall).toBe(false);
|
||||
expect(result.text).toBe(null);
|
||||
});
|
||||
|
||||
test("handles chunks with no text content", () => {
|
||||
const chunk = {
|
||||
event: {
|
||||
metadata: {
|
||||
timestamp: "2024-01-01",
|
||||
},
|
||||
},
|
||||
};
|
||||
const result = processChunk(chunk);
|
||||
expect(result.isToolCall).toBe(false);
|
||||
expect(result.text).toBe(null);
|
||||
});
|
||||
|
||||
test("handles malformed JSON in function calls gracefully", () => {
|
||||
const chunk =
|
||||
'{"type": "function", "name": "knowledge_search"} incomplete json';
|
||||
const result = processChunk(chunk);
|
||||
expect(result.isToolCall).toBe(true);
|
||||
expect(result.text).toBe(null);
|
||||
});
|
||||
});
|
||||
});
|
|
@ -31,6 +31,9 @@ const mockClient = {
|
|||
toolgroups: {
|
||||
list: jest.fn(),
|
||||
},
|
||||
vectorDBs: {
|
||||
list: jest.fn(),
|
||||
},
|
||||
};
|
||||
|
||||
jest.mock("@/hooks/use-auth-client", () => ({
|
||||
|
@ -164,7 +167,7 @@ describe("ChatPlaygroundPage", () => {
|
|||
session_name: "Test Session",
|
||||
started_at: new Date().toISOString(),
|
||||
turns: [],
|
||||
}); // No turns by default
|
||||
});
|
||||
mockClient.agents.retrieve.mockResolvedValue({
|
||||
agent_id: "test-agent",
|
||||
agent_config: {
|
||||
|
@ -417,7 +420,6 @@ describe("ChatPlaygroundPage", () => {
|
|||
});
|
||||
|
||||
await waitFor(() => {
|
||||
// first agent should be auto-selected
|
||||
expect(mockClient.agents.session.create).toHaveBeenCalledWith(
|
||||
"agent_123",
|
||||
{ session_name: "Default Session" }
|
||||
|
@ -464,7 +466,7 @@ describe("ChatPlaygroundPage", () => {
|
|||
});
|
||||
});
|
||||
|
||||
test("hides delete button when only one agent exists", async () => {
|
||||
test("shows delete button even when only one agent exists", async () => {
|
||||
mockClient.agents.list.mockResolvedValue({
|
||||
data: [mockAgents[0]],
|
||||
});
|
||||
|
@ -474,9 +476,7 @@ describe("ChatPlaygroundPage", () => {
|
|||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.queryByTitle("Delete current agent")
|
||||
).not.toBeInTheDocument();
|
||||
expect(screen.getByTitle("Delete current agent")).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
|
@ -505,7 +505,7 @@ describe("ChatPlaygroundPage", () => {
|
|||
await waitFor(() => {
|
||||
expect(mockClient.agents.delete).toHaveBeenCalledWith("agent_123");
|
||||
expect(global.confirm).toHaveBeenCalledWith(
|
||||
"Are you sure you want to delete this agent? This action cannot be undone and will delete all associated sessions."
|
||||
"Are you sure you want to delete this agent? This action cannot be undone and will delete the agent and all its sessions."
|
||||
);
|
||||
});
|
||||
|
||||
|
@ -584,4 +584,207 @@ describe("ChatPlaygroundPage", () => {
|
|||
consoleSpy.mockRestore();
|
||||
});
|
||||
});
|
||||
|
||||
describe("RAG File Upload", () => {
|
||||
let mockFileReader: {
|
||||
readAsDataURL: jest.Mock;
|
||||
readAsText: jest.Mock;
|
||||
result: string | null;
|
||||
onload: (() => void) | null;
|
||||
onerror: (() => void) | null;
|
||||
};
|
||||
let mockRAGTool: {
|
||||
insert: jest.Mock;
|
||||
};
|
||||
|
||||
beforeEach(() => {
|
||||
mockFileReader = {
|
||||
readAsDataURL: jest.fn(),
|
||||
readAsText: jest.fn(),
|
||||
result: null,
|
||||
onload: null,
|
||||
onerror: null,
|
||||
};
|
||||
global.FileReader = jest.fn(() => mockFileReader);
|
||||
|
||||
mockRAGTool = {
|
||||
insert: jest.fn().mockResolvedValue({}),
|
||||
};
|
||||
mockClient.toolRuntime = {
|
||||
ragTool: mockRAGTool,
|
||||
};
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
jest.clearAllMocks();
|
||||
});
|
||||
|
||||
test("handles text file upload", async () => {
|
||||
new File(["Hello, world!"], "test.txt", {
|
||||
type: "text/plain",
|
||||
});
|
||||
|
||||
mockClient.agents.retrieve.mockResolvedValue({
|
||||
agent_id: "test-agent",
|
||||
agent_config: {
|
||||
toolgroups: [
|
||||
{
|
||||
name: "builtin::rag/knowledge_search",
|
||||
args: { vector_db_ids: ["test-vector-db"] },
|
||||
},
|
||||
],
|
||||
},
|
||||
});
|
||||
|
||||
await act(async () => {
|
||||
render(<ChatPlaygroundPage />);
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByTestId("chat-component")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
const chatComponent = screen.getByTestId("chat-component");
|
||||
chatComponent.getAttribute("data-onragfileupload");
|
||||
|
||||
// this is a simplified test
|
||||
expect(mockRAGTool.insert).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
test("handles PDF file upload with FileReader", async () => {
|
||||
new File([new ArrayBuffer(1000)], "test.pdf", {
|
||||
type: "application/pdf",
|
||||
});
|
||||
|
||||
const mockDataURL = "data:application/pdf;base64,JVBERi0xLjQK";
|
||||
mockFileReader.result = mockDataURL;
|
||||
|
||||
mockClient.agents.retrieve.mockResolvedValue({
|
||||
agent_id: "test-agent",
|
||||
agent_config: {
|
||||
toolgroups: [
|
||||
{
|
||||
name: "builtin::rag/knowledge_search",
|
||||
args: { vector_db_ids: ["test-vector-db"] },
|
||||
},
|
||||
],
|
||||
},
|
||||
});
|
||||
|
||||
await act(async () => {
|
||||
render(<ChatPlaygroundPage />);
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByTestId("chat-component")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
expect(global.FileReader).toBeDefined();
|
||||
});
|
||||
|
||||
test("handles different file types correctly", () => {
|
||||
const getContentType = (filename: string): string => {
|
||||
const ext = filename.toLowerCase().split(".").pop();
|
||||
switch (ext) {
|
||||
case "pdf":
|
||||
return "application/pdf";
|
||||
case "txt":
|
||||
return "text/plain";
|
||||
case "md":
|
||||
return "text/markdown";
|
||||
case "html":
|
||||
return "text/html";
|
||||
case "csv":
|
||||
return "text/csv";
|
||||
case "json":
|
||||
return "application/json";
|
||||
case "docx":
|
||||
return "application/vnd.openxmlformats-officedocument.wordprocessingml.document";
|
||||
case "doc":
|
||||
return "application/msword";
|
||||
default:
|
||||
return "application/octet-stream";
|
||||
}
|
||||
};
|
||||
|
||||
expect(getContentType("test.pdf")).toBe("application/pdf");
|
||||
expect(getContentType("test.txt")).toBe("text/plain");
|
||||
expect(getContentType("test.md")).toBe("text/markdown");
|
||||
expect(getContentType("test.html")).toBe("text/html");
|
||||
expect(getContentType("test.csv")).toBe("text/csv");
|
||||
expect(getContentType("test.json")).toBe("application/json");
|
||||
expect(getContentType("test.docx")).toBe(
|
||||
"application/vnd.openxmlformats-officedocument.wordprocessingml.document"
|
||||
);
|
||||
expect(getContentType("test.doc")).toBe("application/msword");
|
||||
expect(getContentType("test.unknown")).toBe("application/octet-stream");
|
||||
});
|
||||
|
||||
test("determines text vs binary file types correctly", () => {
|
||||
const isTextFile = (mimeType: string): boolean => {
|
||||
return (
|
||||
mimeType.startsWith("text/") ||
|
||||
mimeType === "application/json" ||
|
||||
mimeType === "text/markdown" ||
|
||||
mimeType === "text/html" ||
|
||||
mimeType === "text/csv"
|
||||
);
|
||||
};
|
||||
|
||||
expect(isTextFile("text/plain")).toBe(true);
|
||||
expect(isTextFile("text/markdown")).toBe(true);
|
||||
expect(isTextFile("text/html")).toBe(true);
|
||||
expect(isTextFile("text/csv")).toBe(true);
|
||||
expect(isTextFile("application/json")).toBe(true);
|
||||
|
||||
expect(isTextFile("application/pdf")).toBe(false);
|
||||
expect(isTextFile("application/msword")).toBe(false);
|
||||
expect(
|
||||
isTextFile(
|
||||
"application/vnd.openxmlformats-officedocument.wordprocessingml.document"
|
||||
)
|
||||
).toBe(false);
|
||||
expect(isTextFile("application/octet-stream")).toBe(false);
|
||||
});
|
||||
|
||||
test("handles FileReader error gracefully", async () => {
|
||||
const pdfFile = new File([new ArrayBuffer(1000)], "test.pdf", {
|
||||
type: "application/pdf",
|
||||
});
|
||||
|
||||
mockFileReader.onerror = jest.fn();
|
||||
const mockError = new Error("FileReader failed");
|
||||
|
||||
const fileReaderPromise = new Promise<string>((resolve, reject) => {
|
||||
const reader = new FileReader();
|
||||
reader.onload = () => resolve(reader.result as string);
|
||||
reader.onerror = () => reject(reader.error || mockError);
|
||||
reader.readAsDataURL(pdfFile);
|
||||
|
||||
setTimeout(() => {
|
||||
reader.onerror?.(new ProgressEvent("error"));
|
||||
}, 0);
|
||||
});
|
||||
|
||||
await expect(fileReaderPromise).rejects.toBeDefined();
|
||||
});
|
||||
|
||||
test("handles large file upload with FileReader approach", () => {
|
||||
// create a large file
|
||||
const largeFile = new File(
|
||||
[new ArrayBuffer(10 * 1024 * 1024)],
|
||||
"large.pdf",
|
||||
{
|
||||
type: "application/pdf",
|
||||
}
|
||||
);
|
||||
|
||||
expect(largeFile.size).toBe(10 * 1024 * 1024); // 10MB
|
||||
|
||||
expect(global.FileReader).toBeDefined();
|
||||
|
||||
const reader = new FileReader();
|
||||
expect(reader.readAsDataURL).toBeDefined();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
File diff suppressed because it is too large
Load diff
|
@ -35,6 +35,7 @@ interface ChatPropsBase {
|
|||
) => void;
|
||||
setMessages?: (messages: Message[]) => void;
|
||||
transcribeAudio?: (blob: Blob) => Promise<string>;
|
||||
onRAGFileUpload?: (file: File) => Promise<void>;
|
||||
}
|
||||
|
||||
interface ChatPropsWithoutSuggestions extends ChatPropsBase {
|
||||
|
@ -62,6 +63,7 @@ export function Chat({
|
|||
onRateResponse,
|
||||
setMessages,
|
||||
transcribeAudio,
|
||||
onRAGFileUpload,
|
||||
}: ChatProps) {
|
||||
const lastMessage = messages.at(-1);
|
||||
const isEmpty = messages.length === 0;
|
||||
|
@ -226,16 +228,17 @@ export function Chat({
|
|||
isPending={isGenerating || isTyping}
|
||||
handleSubmit={handleSubmit}
|
||||
>
|
||||
{({ files, setFiles }) => (
|
||||
{() => (
|
||||
<MessageInput
|
||||
value={input}
|
||||
onChange={handleInputChange}
|
||||
allowAttachments
|
||||
files={files}
|
||||
setFiles={setFiles}
|
||||
allowAttachments={true}
|
||||
files={null}
|
||||
setFiles={() => {}}
|
||||
stop={handleStop}
|
||||
isGenerating={isGenerating}
|
||||
transcribeAudio={transcribeAudio}
|
||||
onRAGFileUpload={onRAGFileUpload}
|
||||
/>
|
||||
)}
|
||||
</ChatForm>
|
||||
|
|
|
@ -14,6 +14,7 @@ import { Card } from "@/components/ui/card";
|
|||
import { Trash2 } from "lucide-react";
|
||||
import type { Message } from "@/components/chat-playground/chat-message";
|
||||
import { useAuthClient } from "@/hooks/use-auth-client";
|
||||
import { cleanMessageContent } from "@/lib/message-content-utils";
|
||||
import type {
|
||||
Session,
|
||||
SessionCreateParams,
|
||||
|
@ -219,10 +220,7 @@ export function Conversations({
|
|||
messages.push({
|
||||
id: `${turn.turn_id}-assistant-${messages.length}`,
|
||||
role: "assistant",
|
||||
content:
|
||||
typeof turn.output_message.content === "string"
|
||||
? turn.output_message.content
|
||||
: JSON.stringify(turn.output_message.content),
|
||||
content: cleanMessageContent(turn.output_message.content),
|
||||
createdAt: new Date(
|
||||
turn.completed_at || turn.started_at || Date.now()
|
||||
),
|
||||
|
@ -271,7 +269,7 @@ export function Conversations({
|
|||
);
|
||||
|
||||
const deleteSession = async (sessionId: string) => {
|
||||
if (sessions.length <= 1 || !selectedAgentId) {
|
||||
if (!selectedAgentId) {
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -324,7 +322,6 @@ export function Conversations({
|
|||
}
|
||||
}, [currentSession]);
|
||||
|
||||
// Don't render if no agent is selected
|
||||
if (!selectedAgentId) {
|
||||
return null;
|
||||
}
|
||||
|
@ -357,7 +354,7 @@ export function Conversations({
|
|||
+ New
|
||||
</Button>
|
||||
|
||||
{currentSession && sessions.length > 1 && (
|
||||
{currentSession && (
|
||||
<Button
|
||||
onClick={() => deleteSession(currentSession.id)}
|
||||
variant="outline"
|
||||
|
|
|
@ -21,6 +21,7 @@ interface MessageInputBaseProps
|
|||
isGenerating: boolean;
|
||||
enableInterrupt?: boolean;
|
||||
transcribeAudio?: (blob: Blob) => Promise<string>;
|
||||
onRAGFileUpload?: (file: File) => Promise<void>;
|
||||
}
|
||||
|
||||
interface MessageInputWithoutAttachmentProps extends MessageInputBaseProps {
|
||||
|
@ -213,8 +214,13 @@ export function MessageInput({
|
|||
className
|
||||
)}
|
||||
{...(props.allowAttachments
|
||||
? omit(props, ["allowAttachments", "files", "setFiles"])
|
||||
: omit(props, ["allowAttachments"]))}
|
||||
? omit(props, [
|
||||
"allowAttachments",
|
||||
"files",
|
||||
"setFiles",
|
||||
"onRAGFileUpload",
|
||||
])
|
||||
: omit(props, ["allowAttachments", "onRAGFileUpload"]))}
|
||||
/>
|
||||
|
||||
{props.allowAttachments && (
|
||||
|
@ -254,11 +260,19 @@ export function MessageInput({
|
|||
size="icon"
|
||||
variant="outline"
|
||||
className="h-8 w-8"
|
||||
aria-label="Attach a file"
|
||||
disabled={true}
|
||||
aria-label="Upload file to RAG"
|
||||
disabled={false}
|
||||
onClick={async () => {
|
||||
const files = await showFileUploadDialog();
|
||||
addFiles(files);
|
||||
const input = document.createElement("input");
|
||||
input.type = "file";
|
||||
input.accept = ".pdf,.txt,.md,.html,.csv,.json";
|
||||
input.onchange = async e => {
|
||||
const file = (e.target as HTMLInputElement).files?.[0];
|
||||
if (file && props.onRAGFileUpload) {
|
||||
await props.onRAGFileUpload(file);
|
||||
}
|
||||
};
|
||||
input.click();
|
||||
}}
|
||||
>
|
||||
<Paperclip className="h-4 w-4" />
|
||||
|
@ -337,28 +351,6 @@ function FileUploadOverlay({ isDragging }: FileUploadOverlayProps) {
|
|||
);
|
||||
}
|
||||
|
||||
function showFileUploadDialog() {
|
||||
const input = document.createElement("input");
|
||||
|
||||
input.type = "file";
|
||||
input.multiple = true;
|
||||
input.accept = "*/*";
|
||||
input.click();
|
||||
|
||||
return new Promise<File[] | null>(resolve => {
|
||||
input.onchange = e => {
|
||||
const files = (e.currentTarget as HTMLInputElement).files;
|
||||
|
||||
if (files) {
|
||||
resolve(Array.from(files));
|
||||
return;
|
||||
}
|
||||
|
||||
resolve(null);
|
||||
};
|
||||
});
|
||||
}
|
||||
|
||||
function TranscribingOverlay() {
|
||||
return (
|
||||
<motion.div
|
||||
|
|
243
llama_stack/ui/components/chat-playground/vector-db-creator.tsx
Normal file
243
llama_stack/ui/components/chat-playground/vector-db-creator.tsx
Normal file
|
@ -0,0 +1,243 @@
|
|||
"use client";
|
||||
|
||||
import { useState, useEffect } from "react";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { Input } from "@/components/ui/input";
|
||||
import { Card } from "@/components/ui/card";
|
||||
import {
|
||||
Select,
|
||||
SelectContent,
|
||||
SelectItem,
|
||||
SelectTrigger,
|
||||
SelectValue,
|
||||
} from "@/components/ui/select";
|
||||
import { useAuthClient } from "@/hooks/use-auth-client";
|
||||
import type { Model } from "llama-stack-client/resources/models";
|
||||
|
||||
interface VectorDBCreatorProps {
|
||||
models: Model[];
|
||||
onVectorDBCreated?: (vectorDbId: string) => void;
|
||||
onCancel?: () => void;
|
||||
}
|
||||
|
||||
interface VectorDBProvider {
|
||||
api: string;
|
||||
provider_id: string;
|
||||
provider_type: string;
|
||||
}
|
||||
|
||||
export function VectorDBCreator({
|
||||
models,
|
||||
onVectorDBCreated,
|
||||
onCancel,
|
||||
}: VectorDBCreatorProps) {
|
||||
const [vectorDbName, setVectorDbName] = useState("");
|
||||
const [selectedEmbeddingModel, setSelectedEmbeddingModel] = useState("");
|
||||
const [selectedProvider, setSelectedProvider] = useState("faiss");
|
||||
const [availableProviders, setAvailableProviders] = useState<
|
||||
VectorDBProvider[]
|
||||
>([]);
|
||||
const [isCreating, setIsCreating] = useState(false);
|
||||
const [isLoadingProviders, setIsLoadingProviders] = useState(false);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
const client = useAuthClient();
|
||||
|
||||
const embeddingModels = models.filter(
|
||||
model => model.model_type === "embedding"
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
const fetchProviders = async () => {
|
||||
setIsLoadingProviders(true);
|
||||
try {
|
||||
const providersResponse = await client.providers.list();
|
||||
|
||||
const vectorIoProviders = providersResponse.filter(
|
||||
(provider: VectorDBProvider) => provider.api === "vector_io"
|
||||
);
|
||||
|
||||
setAvailableProviders(vectorIoProviders);
|
||||
|
||||
if (vectorIoProviders.length > 0) {
|
||||
const faissProvider = vectorIoProviders.find(
|
||||
(p: VectorDBProvider) => p.provider_id === "faiss"
|
||||
);
|
||||
setSelectedProvider(
|
||||
faissProvider?.provider_id || vectorIoProviders[0].provider_id
|
||||
);
|
||||
}
|
||||
} catch (err) {
|
||||
console.error("Error fetching providers:", err);
|
||||
setAvailableProviders([
|
||||
{
|
||||
api: "vector_io",
|
||||
provider_id: "faiss",
|
||||
provider_type: "inline::faiss",
|
||||
},
|
||||
]);
|
||||
} finally {
|
||||
setIsLoadingProviders(false);
|
||||
}
|
||||
};
|
||||
|
||||
fetchProviders();
|
||||
}, [client]);
|
||||
|
||||
const handleCreate = async () => {
|
||||
if (!vectorDbName.trim() || !selectedEmbeddingModel) {
|
||||
setError("Please provide a name and select an embedding model");
|
||||
return;
|
||||
}
|
||||
|
||||
setIsCreating(true);
|
||||
setError(null);
|
||||
|
||||
try {
|
||||
const embeddingModel = embeddingModels.find(
|
||||
m => m.identifier === selectedEmbeddingModel
|
||||
);
|
||||
|
||||
if (!embeddingModel) {
|
||||
throw new Error("Selected embedding model not found");
|
||||
}
|
||||
|
||||
const embeddingDimension = embeddingModel.metadata
|
||||
?.embedding_dimension as number;
|
||||
|
||||
if (!embeddingDimension) {
|
||||
throw new Error("Embedding dimension not available for selected model");
|
||||
}
|
||||
|
||||
const vectorDbId = vectorDbName.trim() || `vector_db_${Date.now()}`;
|
||||
|
||||
const response = await client.vectorDBs.register({
|
||||
vector_db_id: vectorDbId,
|
||||
embedding_model: selectedEmbeddingModel,
|
||||
embedding_dimension: embeddingDimension,
|
||||
provider_id: selectedProvider,
|
||||
});
|
||||
|
||||
onVectorDBCreated?.(response.identifier || vectorDbId);
|
||||
} catch (err) {
|
||||
console.error("Error creating vector DB:", err);
|
||||
setError(
|
||||
err instanceof Error ? err.message : "Failed to create vector DB"
|
||||
);
|
||||
} finally {
|
||||
setIsCreating(false);
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<Card className="p-6 space-y-4">
|
||||
<h3 className="text-lg font-semibold">Create Vector Database</h3>
|
||||
|
||||
<div className="space-y-4">
|
||||
<div>
|
||||
<label className="text-sm font-medium block mb-2">
|
||||
Vector DB Name
|
||||
</label>
|
||||
<Input
|
||||
value={vectorDbName}
|
||||
onChange={e => setVectorDbName(e.target.value)}
|
||||
placeholder="My Vector Database"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<label className="text-sm font-medium block mb-2">
|
||||
Embedding Model
|
||||
</label>
|
||||
<Select
|
||||
value={selectedEmbeddingModel}
|
||||
onValueChange={setSelectedEmbeddingModel}
|
||||
>
|
||||
<SelectTrigger>
|
||||
<SelectValue placeholder="Select Embedding Model" />
|
||||
</SelectTrigger>
|
||||
<SelectContent>
|
||||
{embeddingModels.map(model => (
|
||||
<SelectItem key={model.identifier} value={model.identifier}>
|
||||
{model.identifier}
|
||||
</SelectItem>
|
||||
))}
|
||||
</SelectContent>
|
||||
</Select>
|
||||
{selectedEmbeddingModel && (
|
||||
<p className="text-xs text-muted-foreground mt-1">
|
||||
Dimension:{" "}
|
||||
{embeddingModels.find(
|
||||
m => m.identifier === selectedEmbeddingModel
|
||||
)?.metadata?.embedding_dimension || "Unknown"}
|
||||
</p>
|
||||
)}
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<label className="text-sm font-medium block mb-2">
|
||||
Vector Database Provider
|
||||
</label>
|
||||
<Select
|
||||
value={selectedProvider}
|
||||
onValueChange={setSelectedProvider}
|
||||
disabled={isLoadingProviders}
|
||||
>
|
||||
<SelectTrigger>
|
||||
<SelectValue
|
||||
placeholder={
|
||||
isLoadingProviders
|
||||
? "Loading providers..."
|
||||
: "Select Provider"
|
||||
}
|
||||
/>
|
||||
</SelectTrigger>
|
||||
<SelectContent>
|
||||
{availableProviders.map(provider => (
|
||||
<SelectItem
|
||||
key={provider.provider_id}
|
||||
value={provider.provider_id}
|
||||
>
|
||||
{provider.provider_id}
|
||||
</SelectItem>
|
||||
))}
|
||||
</SelectContent>
|
||||
</Select>
|
||||
{selectedProvider && (
|
||||
<p className="text-xs text-muted-foreground mt-1">
|
||||
Selected provider: {selectedProvider}
|
||||
</p>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{error && (
|
||||
<div className="text-destructive text-sm bg-destructive/10 p-2 rounded">
|
||||
{error}
|
||||
</div>
|
||||
)}
|
||||
|
||||
<div className="flex gap-2 pt-2">
|
||||
<Button
|
||||
onClick={handleCreate}
|
||||
disabled={
|
||||
isCreating || !vectorDbName.trim() || !selectedEmbeddingModel
|
||||
}
|
||||
className="flex-1"
|
||||
>
|
||||
{isCreating ? "Creating..." : "Create Vector DB"}
|
||||
</Button>
|
||||
{onCancel && (
|
||||
<Button variant="outline" onClick={onCancel} className="flex-1">
|
||||
Cancel
|
||||
</Button>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="text-xs text-muted-foreground bg-muted/50 p-3 rounded">
|
||||
<strong>Note:</strong> This will create a new vector database that can
|
||||
be used with RAG tools. After creation, you'll be able to upload
|
||||
documents and use it for knowledge search in your agent conversations.
|
||||
</div>
|
||||
</Card>
|
||||
);
|
||||
}
|
51
llama_stack/ui/lib/message-content-utils.ts
Normal file
51
llama_stack/ui/lib/message-content-utils.ts
Normal file
|
@ -0,0 +1,51 @@
|
|||
// check if content contains function call JSON
|
||||
export const containsToolCall = (content: string): boolean => {
|
||||
return (
|
||||
content.includes('"type": "function"') ||
|
||||
content.includes('"name": "knowledge_search"') ||
|
||||
content.includes('"parameters":') ||
|
||||
!!content.match(/\{"type":\s*"function".*?\}/)
|
||||
);
|
||||
};
|
||||
|
||||
export const extractCleanText = (content: string): string | null => {
|
||||
if (containsToolCall(content)) {
|
||||
try {
|
||||
// parse and extract non-function call parts
|
||||
const jsonMatch = content.match(/\{"type":\s*"function"[^}]*\}[^}]*\}/);
|
||||
if (jsonMatch) {
|
||||
const jsonPart = jsonMatch[0];
|
||||
const parsedJson = JSON.parse(jsonPart);
|
||||
|
||||
// if function call, extract text after JSON
|
||||
if (parsedJson.type === "function") {
|
||||
const textAfterJson = content
|
||||
.substring(content.indexOf(jsonPart) + jsonPart.length)
|
||||
.trim();
|
||||
return textAfterJson || null;
|
||||
}
|
||||
}
|
||||
return null;
|
||||
} catch {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
return content;
|
||||
};
|
||||
|
||||
// removes function call JSON handling different content types
|
||||
export const cleanMessageContent = (
|
||||
content: string | unknown[] | unknown
|
||||
): string => {
|
||||
if (typeof content === "string") {
|
||||
const cleaned = extractCleanText(content);
|
||||
return cleaned || "";
|
||||
} else if (Array.isArray(content)) {
|
||||
return content
|
||||
.filter((item: { type: string }) => item.type === "text")
|
||||
.map((item: { text: string }) => item.text)
|
||||
.join("");
|
||||
} else {
|
||||
return JSON.stringify(content);
|
||||
}
|
||||
};
|
337
llama_stack/ui/package-lock.json
generated
337
llama_stack/ui/package-lock.json
generated
|
@ -14,11 +14,11 @@
|
|||
"@radix-ui/react-select": "^2.2.5",
|
||||
"@radix-ui/react-separator": "^1.1.7",
|
||||
"@radix-ui/react-slot": "^1.2.3",
|
||||
"@radix-ui/react-tooltip": "^1.2.6",
|
||||
"@radix-ui/react-tooltip": "^1.2.8",
|
||||
"class-variance-authority": "^0.7.1",
|
||||
"clsx": "^2.1.1",
|
||||
"framer-motion": "^11.18.2",
|
||||
"llama-stack-client": "^0.2.18",
|
||||
"framer-motion": "^12.23.12",
|
||||
"llama-stack-client": "^0.2.20",
|
||||
"lucide-react": "^0.510.0",
|
||||
"next": "15.3.3",
|
||||
"next-auth": "^4.24.11",
|
||||
|
@ -36,19 +36,19 @@
|
|||
"@eslint/eslintrc": "^3",
|
||||
"@tailwindcss/postcss": "^4",
|
||||
"@testing-library/dom": "^10.4.1",
|
||||
"@testing-library/jest-dom": "^6.6.3",
|
||||
"@testing-library/jest-dom": "^6.8.0",
|
||||
"@testing-library/react": "^16.3.0",
|
||||
"@types/jest": "^29.5.14",
|
||||
"@types/node": "^20",
|
||||
"@types/node": "^24",
|
||||
"@types/react": "^19",
|
||||
"@types/react-dom": "^19",
|
||||
"eslint": "^9",
|
||||
"eslint-config-next": "15.3.2",
|
||||
"eslint-config-next": "15.5.2",
|
||||
"eslint-config-prettier": "^10.1.8",
|
||||
"eslint-plugin-prettier": "^5.5.4",
|
||||
"jest": "^29.7.0",
|
||||
"jest-environment-jsdom": "^29.7.0",
|
||||
"prettier": "3.5.3",
|
||||
"prettier": "3.6.2",
|
||||
"tailwindcss": "^4",
|
||||
"ts-node": "^10.9.2",
|
||||
"tw-animate-css": "^1.2.9",
|
||||
|
@ -1854,9 +1854,9 @@
|
|||
"integrity": "sha512-OdiMrzCl2Xi0VTjiQQUK0Xh7bJHnOuET2s+3V+Y40WJBAXrJeGA3f+I8MZJ/YQ3mVGi5XGR1L66oFlgqXhQ4Vw=="
|
||||
},
|
||||
"node_modules/@next/eslint-plugin-next": {
|
||||
"version": "15.3.2",
|
||||
"resolved": "https://registry.npmjs.org/@next/eslint-plugin-next/-/eslint-plugin-next-15.3.2.tgz",
|
||||
"integrity": "sha512-ijVRTXBgnHT33aWnDtmlG+LJD+5vhc9AKTJPquGG5NKXjpKNjc62woIhFtrAcWdBobt8kqjCoaJ0q6sDQoX7aQ==",
|
||||
"version": "15.5.2",
|
||||
"resolved": "https://registry.npmjs.org/@next/eslint-plugin-next/-/eslint-plugin-next-15.5.2.tgz",
|
||||
"integrity": "sha512-lkLrRVxcftuOsJNhWatf1P2hNVfh98k/omQHrCEPPriUypR6RcS13IvLdIrEvkm9AH2Nu2YpR5vLqBuy6twH3Q==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
|
@ -2861,29 +2861,6 @@
|
|||
}
|
||||
}
|
||||
},
|
||||
"node_modules/@radix-ui/react-select/node_modules/@radix-ui/react-visually-hidden": {
|
||||
"version": "1.2.3",
|
||||
"resolved": "https://registry.npmjs.org/@radix-ui/react-visually-hidden/-/react-visually-hidden-1.2.3.tgz",
|
||||
"integrity": "sha512-pzJq12tEaaIhqjbzpCuv/OypJY/BPavOofm+dbab+MHLajy277+1lLm6JFcGgF5eskJ6mquGirhXY2GD/8u8Ug==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"@radix-ui/react-primitive": "2.1.3"
|
||||
},
|
||||
"peerDependencies": {
|
||||
"@types/react": "*",
|
||||
"@types/react-dom": "*",
|
||||
"react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc",
|
||||
"react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc"
|
||||
},
|
||||
"peerDependenciesMeta": {
|
||||
"@types/react": {
|
||||
"optional": true
|
||||
},
|
||||
"@types/react-dom": {
|
||||
"optional": true
|
||||
}
|
||||
}
|
||||
},
|
||||
"node_modules/@radix-ui/react-separator": {
|
||||
"version": "1.1.7",
|
||||
"resolved": "https://registry.npmjs.org/@radix-ui/react-separator/-/react-separator-1.1.7.tgz",
|
||||
|
@ -2949,23 +2926,23 @@
|
|||
}
|
||||
},
|
||||
"node_modules/@radix-ui/react-tooltip": {
|
||||
"version": "1.2.6",
|
||||
"resolved": "https://registry.npmjs.org/@radix-ui/react-tooltip/-/react-tooltip-1.2.6.tgz",
|
||||
"integrity": "sha512-zYb+9dc9tkoN2JjBDIIPLQtk3gGyz8FMKoqYTb8EMVQ5a5hBcdHPECrsZVI4NpPAUOixhkoqg7Hj5ry5USowfA==",
|
||||
"version": "1.2.8",
|
||||
"resolved": "https://registry.npmjs.org/@radix-ui/react-tooltip/-/react-tooltip-1.2.8.tgz",
|
||||
"integrity": "sha512-tY7sVt1yL9ozIxvmbtN5qtmH2krXcBCfjEiCgKGLqunJHvgvZG2Pcl2oQ3kbcZARb1BGEHdkLzcYGO8ynVlieg==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"@radix-ui/primitive": "1.1.2",
|
||||
"@radix-ui/primitive": "1.1.3",
|
||||
"@radix-ui/react-compose-refs": "1.1.2",
|
||||
"@radix-ui/react-context": "1.1.2",
|
||||
"@radix-ui/react-dismissable-layer": "1.1.9",
|
||||
"@radix-ui/react-dismissable-layer": "1.1.11",
|
||||
"@radix-ui/react-id": "1.1.1",
|
||||
"@radix-ui/react-popper": "1.2.6",
|
||||
"@radix-ui/react-portal": "1.1.8",
|
||||
"@radix-ui/react-presence": "1.1.4",
|
||||
"@radix-ui/react-primitive": "2.1.2",
|
||||
"@radix-ui/react-slot": "1.2.2",
|
||||
"@radix-ui/react-popper": "1.2.8",
|
||||
"@radix-ui/react-portal": "1.1.9",
|
||||
"@radix-ui/react-presence": "1.1.5",
|
||||
"@radix-ui/react-primitive": "2.1.3",
|
||||
"@radix-ui/react-slot": "1.2.3",
|
||||
"@radix-ui/react-use-controllable-state": "1.2.2",
|
||||
"@radix-ui/react-visually-hidden": "1.2.2"
|
||||
"@radix-ui/react-visually-hidden": "1.2.3"
|
||||
},
|
||||
"peerDependencies": {
|
||||
"@types/react": "*",
|
||||
|
@ -2982,21 +2959,162 @@
|
|||
}
|
||||
}
|
||||
},
|
||||
"node_modules/@radix-ui/react-tooltip/node_modules/@radix-ui/react-slot": {
|
||||
"version": "1.2.2",
|
||||
"resolved": "https://registry.npmjs.org/@radix-ui/react-slot/-/react-slot-1.2.2.tgz",
|
||||
"integrity": "sha512-y7TBO4xN4Y94FvcWIOIh18fM4R1A8S4q1jhoz4PNzOoHsFcN8pogcFmZrTYAm4F9VRUrWP/Mw7xSKybIeRI+CQ==",
|
||||
"node_modules/@radix-ui/react-tooltip/node_modules/@radix-ui/primitive": {
|
||||
"version": "1.1.3",
|
||||
"resolved": "https://registry.npmjs.org/@radix-ui/primitive/-/primitive-1.1.3.tgz",
|
||||
"integrity": "sha512-JTF99U/6XIjCBo0wqkU5sK10glYe27MRRsfwoiq5zzOEZLHU3A3KCMa5X/azekYRCJ0HlwI0crAXS/5dEHTzDg==",
|
||||
"license": "MIT"
|
||||
},
|
||||
"node_modules/@radix-ui/react-tooltip/node_modules/@radix-ui/react-arrow": {
|
||||
"version": "1.1.7",
|
||||
"resolved": "https://registry.npmjs.org/@radix-ui/react-arrow/-/react-arrow-1.1.7.tgz",
|
||||
"integrity": "sha512-F+M1tLhO+mlQaOWspE8Wstg+z6PwxwRd8oQ8IXceWz92kfAmalTRf0EjrouQeo7QssEPfCn05B4Ihs1K9WQ/7w==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"@radix-ui/react-compose-refs": "1.1.2"
|
||||
"@radix-ui/react-primitive": "2.1.3"
|
||||
},
|
||||
"peerDependencies": {
|
||||
"@types/react": "*",
|
||||
"react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc"
|
||||
"@types/react-dom": "*",
|
||||
"react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc",
|
||||
"react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc"
|
||||
},
|
||||
"peerDependenciesMeta": {
|
||||
"@types/react": {
|
||||
"optional": true
|
||||
},
|
||||
"@types/react-dom": {
|
||||
"optional": true
|
||||
}
|
||||
}
|
||||
},
|
||||
"node_modules/@radix-ui/react-tooltip/node_modules/@radix-ui/react-dismissable-layer": {
|
||||
"version": "1.1.11",
|
||||
"resolved": "https://registry.npmjs.org/@radix-ui/react-dismissable-layer/-/react-dismissable-layer-1.1.11.tgz",
|
||||
"integrity": "sha512-Nqcp+t5cTB8BinFkZgXiMJniQH0PsUt2k51FUhbdfeKvc4ACcG2uQniY/8+h1Yv6Kza4Q7lD7PQV0z0oicE0Mg==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"@radix-ui/primitive": "1.1.3",
|
||||
"@radix-ui/react-compose-refs": "1.1.2",
|
||||
"@radix-ui/react-primitive": "2.1.3",
|
||||
"@radix-ui/react-use-callback-ref": "1.1.1",
|
||||
"@radix-ui/react-use-escape-keydown": "1.1.1"
|
||||
},
|
||||
"peerDependencies": {
|
||||
"@types/react": "*",
|
||||
"@types/react-dom": "*",
|
||||
"react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc",
|
||||
"react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc"
|
||||
},
|
||||
"peerDependenciesMeta": {
|
||||
"@types/react": {
|
||||
"optional": true
|
||||
},
|
||||
"@types/react-dom": {
|
||||
"optional": true
|
||||
}
|
||||
}
|
||||
},
|
||||
"node_modules/@radix-ui/react-tooltip/node_modules/@radix-ui/react-popper": {
|
||||
"version": "1.2.8",
|
||||
"resolved": "https://registry.npmjs.org/@radix-ui/react-popper/-/react-popper-1.2.8.tgz",
|
||||
"integrity": "sha512-0NJQ4LFFUuWkE7Oxf0htBKS6zLkkjBH+hM1uk7Ng705ReR8m/uelduy1DBo0PyBXPKVnBA6YBlU94MBGXrSBCw==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"@floating-ui/react-dom": "^2.0.0",
|
||||
"@radix-ui/react-arrow": "1.1.7",
|
||||
"@radix-ui/react-compose-refs": "1.1.2",
|
||||
"@radix-ui/react-context": "1.1.2",
|
||||
"@radix-ui/react-primitive": "2.1.3",
|
||||
"@radix-ui/react-use-callback-ref": "1.1.1",
|
||||
"@radix-ui/react-use-layout-effect": "1.1.1",
|
||||
"@radix-ui/react-use-rect": "1.1.1",
|
||||
"@radix-ui/react-use-size": "1.1.1",
|
||||
"@radix-ui/rect": "1.1.1"
|
||||
},
|
||||
"peerDependencies": {
|
||||
"@types/react": "*",
|
||||
"@types/react-dom": "*",
|
||||
"react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc",
|
||||
"react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc"
|
||||
},
|
||||
"peerDependenciesMeta": {
|
||||
"@types/react": {
|
||||
"optional": true
|
||||
},
|
||||
"@types/react-dom": {
|
||||
"optional": true
|
||||
}
|
||||
}
|
||||
},
|
||||
"node_modules/@radix-ui/react-tooltip/node_modules/@radix-ui/react-portal": {
|
||||
"version": "1.1.9",
|
||||
"resolved": "https://registry.npmjs.org/@radix-ui/react-portal/-/react-portal-1.1.9.tgz",
|
||||
"integrity": "sha512-bpIxvq03if6UNwXZ+HTK71JLh4APvnXntDc6XOX8UVq4XQOVl7lwok0AvIl+b8zgCw3fSaVTZMpAPPagXbKmHQ==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"@radix-ui/react-primitive": "2.1.3",
|
||||
"@radix-ui/react-use-layout-effect": "1.1.1"
|
||||
},
|
||||
"peerDependencies": {
|
||||
"@types/react": "*",
|
||||
"@types/react-dom": "*",
|
||||
"react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc",
|
||||
"react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc"
|
||||
},
|
||||
"peerDependenciesMeta": {
|
||||
"@types/react": {
|
||||
"optional": true
|
||||
},
|
||||
"@types/react-dom": {
|
||||
"optional": true
|
||||
}
|
||||
}
|
||||
},
|
||||
"node_modules/@radix-ui/react-tooltip/node_modules/@radix-ui/react-presence": {
|
||||
"version": "1.1.5",
|
||||
"resolved": "https://registry.npmjs.org/@radix-ui/react-presence/-/react-presence-1.1.5.tgz",
|
||||
"integrity": "sha512-/jfEwNDdQVBCNvjkGit4h6pMOzq8bHkopq458dPt2lMjx+eBQUohZNG9A7DtO/O5ukSbxuaNGXMjHicgwy6rQQ==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"@radix-ui/react-compose-refs": "1.1.2",
|
||||
"@radix-ui/react-use-layout-effect": "1.1.1"
|
||||
},
|
||||
"peerDependencies": {
|
||||
"@types/react": "*",
|
||||
"@types/react-dom": "*",
|
||||
"react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc",
|
||||
"react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc"
|
||||
},
|
||||
"peerDependenciesMeta": {
|
||||
"@types/react": {
|
||||
"optional": true
|
||||
},
|
||||
"@types/react-dom": {
|
||||
"optional": true
|
||||
}
|
||||
}
|
||||
},
|
||||
"node_modules/@radix-ui/react-tooltip/node_modules/@radix-ui/react-primitive": {
|
||||
"version": "2.1.3",
|
||||
"resolved": "https://registry.npmjs.org/@radix-ui/react-primitive/-/react-primitive-2.1.3.tgz",
|
||||
"integrity": "sha512-m9gTwRkhy2lvCPe6QJp4d3G1TYEUHn/FzJUtq9MjH46an1wJU+GdoGC5VLof8RX8Ft/DlpshApkhswDLZzHIcQ==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"@radix-ui/react-slot": "1.2.3"
|
||||
},
|
||||
"peerDependencies": {
|
||||
"@types/react": "*",
|
||||
"@types/react-dom": "*",
|
||||
"react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc",
|
||||
"react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc"
|
||||
},
|
||||
"peerDependenciesMeta": {
|
||||
"@types/react": {
|
||||
"optional": true
|
||||
},
|
||||
"@types/react-dom": {
|
||||
"optional": true
|
||||
}
|
||||
}
|
||||
},
|
||||
|
@ -3137,12 +3255,35 @@
|
|||
}
|
||||
},
|
||||
"node_modules/@radix-ui/react-visually-hidden": {
|
||||
"version": "1.2.2",
|
||||
"resolved": "https://registry.npmjs.org/@radix-ui/react-visually-hidden/-/react-visually-hidden-1.2.2.tgz",
|
||||
"integrity": "sha512-ORCmRUbNiZIv6uV5mhFrhsIKw4UX/N3syZtyqvry61tbGm4JlgQuSn0hk5TwCARsCjkcnuRkSdCE3xfb+ADHew==",
|
||||
"version": "1.2.3",
|
||||
"resolved": "https://registry.npmjs.org/@radix-ui/react-visually-hidden/-/react-visually-hidden-1.2.3.tgz",
|
||||
"integrity": "sha512-pzJq12tEaaIhqjbzpCuv/OypJY/BPavOofm+dbab+MHLajy277+1lLm6JFcGgF5eskJ6mquGirhXY2GD/8u8Ug==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"@radix-ui/react-primitive": "2.1.2"
|
||||
"@radix-ui/react-primitive": "2.1.3"
|
||||
},
|
||||
"peerDependencies": {
|
||||
"@types/react": "*",
|
||||
"@types/react-dom": "*",
|
||||
"react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc",
|
||||
"react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc"
|
||||
},
|
||||
"peerDependenciesMeta": {
|
||||
"@types/react": {
|
||||
"optional": true
|
||||
},
|
||||
"@types/react-dom": {
|
||||
"optional": true
|
||||
}
|
||||
}
|
||||
},
|
||||
"node_modules/@radix-ui/react-visually-hidden/node_modules/@radix-ui/react-primitive": {
|
||||
"version": "2.1.3",
|
||||
"resolved": "https://registry.npmjs.org/@radix-ui/react-primitive/-/react-primitive-2.1.3.tgz",
|
||||
"integrity": "sha512-m9gTwRkhy2lvCPe6QJp4d3G1TYEUHn/FzJUtq9MjH46an1wJU+GdoGC5VLof8RX8Ft/DlpshApkhswDLZzHIcQ==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"@radix-ui/react-slot": "1.2.3"
|
||||
},
|
||||
"peerDependencies": {
|
||||
"@types/react": "*",
|
||||
|
@ -3597,18 +3738,17 @@
|
|||
}
|
||||
},
|
||||
"node_modules/@testing-library/jest-dom": {
|
||||
"version": "6.6.3",
|
||||
"resolved": "https://registry.npmjs.org/@testing-library/jest-dom/-/jest-dom-6.6.3.tgz",
|
||||
"integrity": "sha512-IteBhl4XqYNkM54f4ejhLRJiZNqcSCoXUOG2CPK7qbD322KjQozM4kHQOfkG2oln9b9HTYqs+Sae8vBATubxxA==",
|
||||
"version": "6.8.0",
|
||||
"resolved": "https://registry.npmjs.org/@testing-library/jest-dom/-/jest-dom-6.8.0.tgz",
|
||||
"integrity": "sha512-WgXcWzVM6idy5JaftTVC8Vs83NKRmGJz4Hqs4oyOuO2J4r/y79vvKZsb+CaGyCSEbUPI6OsewfPd0G1A0/TUZQ==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"@adobe/css-tools": "^4.4.0",
|
||||
"aria-query": "^5.0.0",
|
||||
"chalk": "^3.0.0",
|
||||
"css.escape": "^1.5.1",
|
||||
"dom-accessibility-api": "^0.6.3",
|
||||
"lodash": "^4.17.21",
|
||||
"picocolors": "^1.1.1",
|
||||
"redent": "^3.0.0"
|
||||
},
|
||||
"engines": {
|
||||
|
@ -3617,20 +3757,6 @@
|
|||
"yarn": ">=1"
|
||||
}
|
||||
},
|
||||
"node_modules/@testing-library/jest-dom/node_modules/chalk": {
|
||||
"version": "3.0.0",
|
||||
"resolved": "https://registry.npmjs.org/chalk/-/chalk-3.0.0.tgz",
|
||||
"integrity": "sha512-4D3B6Wf41KOYRFdszmDqMCGq5VV/uMAB273JILmO+3jAlh8X4qDtdtgCR3fxtbLEMzSx22QdhnDcJvu2u1fVwg==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"ansi-styles": "^4.1.0",
|
||||
"supports-color": "^7.1.0"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=8"
|
||||
}
|
||||
},
|
||||
"node_modules/@testing-library/jest-dom/node_modules/dom-accessibility-api": {
|
||||
"version": "0.6.3",
|
||||
"resolved": "https://registry.npmjs.org/dom-accessibility-api/-/dom-accessibility-api-0.6.3.tgz",
|
||||
|
@ -3925,12 +4051,12 @@
|
|||
"license": "MIT"
|
||||
},
|
||||
"node_modules/@types/node": {
|
||||
"version": "20.17.47",
|
||||
"resolved": "https://registry.npmjs.org/@types/node/-/node-20.17.47.tgz",
|
||||
"integrity": "sha512-3dLX0Upo1v7RvUimvxLeXqwrfyKxUINk0EAM83swP2mlSUcwV73sZy8XhNz8bcZ3VbsfQyC/y6jRdL5tgCNpDQ==",
|
||||
"version": "24.3.0",
|
||||
"resolved": "https://registry.npmjs.org/@types/node/-/node-24.3.0.tgz",
|
||||
"integrity": "sha512-aPTXCrfwnDLj4VvXrm+UUCQjNEvJgNA8s5F1cvwQU+3KNltTOkBm1j30uNLyqqPNe7gE3KFzImYoZEfLhp4Yow==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"undici-types": "~6.19.2"
|
||||
"undici-types": "~7.10.0"
|
||||
}
|
||||
},
|
||||
"node_modules/@types/node-fetch": {
|
||||
|
@ -6448,13 +6574,13 @@
|
|||
}
|
||||
},
|
||||
"node_modules/eslint-config-next": {
|
||||
"version": "15.3.2",
|
||||
"resolved": "https://registry.npmjs.org/eslint-config-next/-/eslint-config-next-15.3.2.tgz",
|
||||
"integrity": "sha512-FerU4DYccO4FgeYFFglz0SnaKRe1ejXQrDb8kWUkTAg036YWi+jUsgg4sIGNCDhAsDITsZaL4MzBWKB6f4G1Dg==",
|
||||
"version": "15.5.2",
|
||||
"resolved": "https://registry.npmjs.org/eslint-config-next/-/eslint-config-next-15.5.2.tgz",
|
||||
"integrity": "sha512-3hPZghsLupMxxZ2ggjIIrat/bPniM2yRpsVPVM40rp8ZMzKWOJp2CGWn7+EzoV2ddkUr5fxNfHpF+wU1hGt/3g==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"@next/eslint-plugin-next": "15.3.2",
|
||||
"@next/eslint-plugin-next": "15.5.2",
|
||||
"@rushstack/eslint-patch": "^1.10.3",
|
||||
"@typescript-eslint/eslint-plugin": "^5.4.2 || ^6.0.0 || ^7.0.0 || ^8.0.0",
|
||||
"@typescript-eslint/parser": "^5.4.2 || ^6.0.0 || ^7.0.0 || ^8.0.0",
|
||||
|
@ -7283,13 +7409,13 @@
|
|||
}
|
||||
},
|
||||
"node_modules/framer-motion": {
|
||||
"version": "11.18.2",
|
||||
"resolved": "https://registry.npmjs.org/framer-motion/-/framer-motion-11.18.2.tgz",
|
||||
"integrity": "sha512-5F5Och7wrvtLVElIpclDT0CBzMVg3dL22B64aZwHtsIY8RB4mXICLrkajK4G9R+ieSAGcgrLeae2SeUTg2pr6w==",
|
||||
"version": "12.23.12",
|
||||
"resolved": "https://registry.npmjs.org/framer-motion/-/framer-motion-12.23.12.tgz",
|
||||
"integrity": "sha512-6e78rdVtnBvlEVgu6eFEAgG9v3wLnYEboM8I5O5EXvfKC8gxGQB8wXJdhkMy10iVcn05jl6CNw7/HTsTCfwcWg==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"motion-dom": "^11.18.1",
|
||||
"motion-utils": "^11.18.1",
|
||||
"motion-dom": "^12.23.12",
|
||||
"motion-utils": "^12.23.6",
|
||||
"tslib": "^2.4.0"
|
||||
},
|
||||
"peerDependencies": {
|
||||
|
@ -10021,9 +10147,9 @@
|
|||
"license": "MIT"
|
||||
},
|
||||
"node_modules/llama-stack-client": {
|
||||
"version": "0.2.18",
|
||||
"resolved": "https://registry.npmjs.org/llama-stack-client/-/llama-stack-client-0.2.18.tgz",
|
||||
"integrity": "sha512-k+xQOz/TIU0cINP4Aih8q6xs7f/6qs0fLDMXTTKQr5C0F1jtCjRiwsas7bTsDfpKfYhg/7Xy/wPw/uZgi6aIVg==",
|
||||
"version": "0.2.20",
|
||||
"resolved": "https://registry.npmjs.org/llama-stack-client/-/llama-stack-client-0.2.20.tgz",
|
||||
"integrity": "sha512-1vD5nizTX5JEW8TADxKgy/P1W8YZoPSpdnmfxbdYbWgpQ3BWtbvLS6jmDk7VwVA5fRC4895VfHsRDfS1liHarw==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"@types/node": "^18.11.18",
|
||||
|
@ -10066,13 +10192,6 @@
|
|||
"url": "https://github.com/sponsors/sindresorhus"
|
||||
}
|
||||
},
|
||||
"node_modules/lodash": {
|
||||
"version": "4.17.21",
|
||||
"resolved": "https://registry.npmjs.org/lodash/-/lodash-4.17.21.tgz",
|
||||
"integrity": "sha512-v2kDEe57lecTulaDIuNTPy3Ry4gLGJ6Z1O3vE1krgXZNrsQ+LFTGHVxVjcXPs17LhbZVGedAJv8XZ1tvj5FvSg==",
|
||||
"dev": true,
|
||||
"license": "MIT"
|
||||
},
|
||||
"node_modules/lodash.merge": {
|
||||
"version": "4.6.2",
|
||||
"resolved": "https://registry.npmjs.org/lodash.merge/-/lodash.merge-4.6.2.tgz",
|
||||
|
@ -11206,18 +11325,18 @@
|
|||
}
|
||||
},
|
||||
"node_modules/motion-dom": {
|
||||
"version": "11.18.1",
|
||||
"resolved": "https://registry.npmjs.org/motion-dom/-/motion-dom-11.18.1.tgz",
|
||||
"integrity": "sha512-g76KvA001z+atjfxczdRtw/RXOM3OMSdd1f4DL77qCTF/+avrRJiawSG4yDibEQ215sr9kpinSlX2pCTJ9zbhw==",
|
||||
"version": "12.23.12",
|
||||
"resolved": "https://registry.npmjs.org/motion-dom/-/motion-dom-12.23.12.tgz",
|
||||
"integrity": "sha512-RcR4fvMCTESQBD/uKQe49D5RUeDOokkGRmz4ceaJKDBgHYtZtntC/s2vLvY38gqGaytinij/yi3hMcWVcEF5Kw==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"motion-utils": "^11.18.1"
|
||||
"motion-utils": "^12.23.6"
|
||||
}
|
||||
},
|
||||
"node_modules/motion-utils": {
|
||||
"version": "11.18.1",
|
||||
"resolved": "https://registry.npmjs.org/motion-utils/-/motion-utils-11.18.1.tgz",
|
||||
"integrity": "sha512-49Kt+HKjtbJKLtgO/LKj9Ld+6vw9BjH5d9sc40R/kVyH8GLAXgT42M2NnuPcJNuA3s9ZfZBUcwIgpmZWGEE+hA==",
|
||||
"version": "12.23.6",
|
||||
"resolved": "https://registry.npmjs.org/motion-utils/-/motion-utils-12.23.6.tgz",
|
||||
"integrity": "sha512-eAWoPgr4eFEOFfg2WjIsMoqJTW6Z8MTUCgn/GZ3VRpClWBdnbjryiA3ZSNLyxCTmCQx4RmYX6jX1iWHbenUPNQ==",
|
||||
"license": "MIT"
|
||||
},
|
||||
"node_modules/ms": {
|
||||
|
@ -12105,9 +12224,9 @@
|
|||
}
|
||||
},
|
||||
"node_modules/prettier": {
|
||||
"version": "3.5.3",
|
||||
"resolved": "https://registry.npmjs.org/prettier/-/prettier-3.5.3.tgz",
|
||||
"integrity": "sha512-QQtaxnoDJeAkDvDKWCLiwIXkTgRhwYDEQCghU9Z6q03iyek/rxRh/2lC3HB7P8sWT2xC/y5JDctPLBIGzHKbhw==",
|
||||
"version": "3.6.2",
|
||||
"resolved": "https://registry.npmjs.org/prettier/-/prettier-3.6.2.tgz",
|
||||
"integrity": "sha512-I7AIg5boAr5R0FFtJ6rCfD+LFsWHp81dolrFD8S79U9tb8Az2nGrJncnMSnys+bpQJfRUzqs9hnA81OAA3hCuQ==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"bin": {
|
||||
|
@ -14008,9 +14127,9 @@
|
|||
}
|
||||
},
|
||||
"node_modules/undici-types": {
|
||||
"version": "6.19.8",
|
||||
"resolved": "https://registry.npmjs.org/undici-types/-/undici-types-6.19.8.tgz",
|
||||
"integrity": "sha512-ve2KP6f/JnbPBFyobGHuerC9g1FYGn/F8n1LWTwNxCEzd6IfqTwUQcNXgEtmmQ6DlRrC1hrSrBnCZPokRrDHjw==",
|
||||
"version": "7.10.0",
|
||||
"resolved": "https://registry.npmjs.org/undici-types/-/undici-types-7.10.0.tgz",
|
||||
"integrity": "sha512-t5Fy/nfn+14LuOc2KNYg75vZqClpAiqscVvMygNnlsHBFpSXdJaYtXMcdNLpl/Qvc3P2cB3s6lOV51nqsFq4ag==",
|
||||
"license": "MIT"
|
||||
},
|
||||
"node_modules/unified": {
|
||||
|
|
|
@ -19,11 +19,11 @@
|
|||
"@radix-ui/react-select": "^2.2.5",
|
||||
"@radix-ui/react-separator": "^1.1.7",
|
||||
"@radix-ui/react-slot": "^1.2.3",
|
||||
"@radix-ui/react-tooltip": "^1.2.6",
|
||||
"@radix-ui/react-tooltip": "^1.2.8",
|
||||
"class-variance-authority": "^0.7.1",
|
||||
"clsx": "^2.1.1",
|
||||
"framer-motion": "^11.18.2",
|
||||
"llama-stack-client": "^0.2.18",
|
||||
"framer-motion": "^12.23.12",
|
||||
"llama-stack-client": "^0.2.20",
|
||||
"lucide-react": "^0.510.0",
|
||||
"next": "15.3.3",
|
||||
"next-auth": "^4.24.11",
|
||||
|
@ -41,19 +41,19 @@
|
|||
"@eslint/eslintrc": "^3",
|
||||
"@tailwindcss/postcss": "^4",
|
||||
"@testing-library/dom": "^10.4.1",
|
||||
"@testing-library/jest-dom": "^6.6.3",
|
||||
"@testing-library/jest-dom": "^6.8.0",
|
||||
"@testing-library/react": "^16.3.0",
|
||||
"@types/jest": "^29.5.14",
|
||||
"@types/node": "^20",
|
||||
"@types/node": "^24",
|
||||
"@types/react": "^19",
|
||||
"@types/react-dom": "^19",
|
||||
"eslint": "^9",
|
||||
"eslint-config-next": "15.3.2",
|
||||
"eslint-config-next": "15.5.2",
|
||||
"eslint-config-prettier": "^10.1.8",
|
||||
"eslint-plugin-prettier": "^5.5.4",
|
||||
"jest": "^29.7.0",
|
||||
"jest-environment-jsdom": "^29.7.0",
|
||||
"prettier": "3.5.3",
|
||||
"prettier": "3.6.2",
|
||||
"tailwindcss": "^4",
|
||||
"ts-node": "^10.9.2",
|
||||
"tw-animate-css": "^1.2.9",
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue