Merge remote-tracking branch 'origin/main' into support_3pt1_8b

Botao Chen 2025-01-02 15:22:25 -08:00
commit 8cdc91e619
52 changed files with 673 additions and 226 deletions

View file

@ -544,7 +544,7 @@
" provider_type: inline::meta-reference\n", " provider_type: inline::meta-reference\n",
" inference:\n", " inference:\n",
" - config:\n", " - config:\n",
" api_key: 4985b03e627419b2964d34b8519ac6c4319f094d1ffb4f45514b4eb87e5427a2\n", " api_key: <...>\n",
" url: <span style=\"color: #0000ff; text-decoration-color: #0000ff; text-decoration: underline\">https://api.together.xyz/v1</span>\n", " url: <span style=\"color: #0000ff; text-decoration-color: #0000ff; text-decoration: underline\">https://api.together.xyz/v1</span>\n",
" provider_id: together\n", " provider_id: together\n",
" provider_type: remote::together\n", " provider_type: remote::together\n",
@ -663,7 +663,7 @@
" provider_type: inline::meta-reference\n", " provider_type: inline::meta-reference\n",
" inference:\n", " inference:\n",
" - config:\n", " - config:\n",
" api_key: 4985b03e627419b2964d34b8519ac6c4319f094d1ffb4f45514b4eb87e5427a2\n", " api_key: <...>\n",
" url: \u001b[4;94mhttps://api.together.xyz/v1\u001b[0m\n", " url: \u001b[4;94mhttps://api.together.xyz/v1\u001b[0m\n",
" provider_id: together\n", " provider_id: together\n",
" provider_type: remote::together\n", " provider_type: remote::together\n",

View file

@ -338,8 +338,8 @@ distribution_spec:
inference: remote::ollama inference: remote::ollama
memory: inline::faiss memory: inline::faiss
safety: inline::llama-guard safety: inline::llama-guard
agents: meta-reference agents: inline::meta-reference
telemetry: meta-reference telemetry: inline::meta-reference
image_type: conda image_type: conda
``` ```

View file

@ -358,7 +358,7 @@
" if not stream:\n", " if not stream:\n",
" cprint(f'> Response: {response.completion_message.content}', 'cyan')\n", " cprint(f'> Response: {response.completion_message.content}', 'cyan')\n",
" else:\n", " else:\n",
" async for log in EventLogger().log(response):\n", " for log in EventLogger().log(response):\n",
" log.print()\n", " log.print()\n",
"\n", "\n",
"# In a Jupyter Notebook cell, use `await` to call the function\n", "# In a Jupyter Notebook cell, use `await` to call the function\n",
@ -366,16 +366,6 @@
"# To run it in a python file, use this line instead\n", "# To run it in a python file, use this line instead\n",
"# asyncio.run(run_main())\n" "# asyncio.run(run_main())\n"
] ]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "9399aecc",
"metadata": {},
"outputs": [],
"source": [
"#fin"
]
} }
], ],
"metadata": { "metadata": {

View file

@ -45,7 +45,7 @@ If you're looking for more specific topics, we have a [Zero to Hero Guide](#next
--- ---
## Install Dependencies and Set Up Environment ## Install Dependencies and Set Up Environment
1. **Create a Conda Environment**: 1. **Create a Conda Environment**:
Create a new Conda environment with Python 3.10: Create a new Conda environment with Python 3.10:
@ -73,7 +73,7 @@ If you're looking for more specific topics, we have a [Zero to Hero Guide](#next
Open a new terminal and install `llama-stack`: Open a new terminal and install `llama-stack`:
```bash ```bash
conda activate ollama conda activate ollama
pip install llama-stack==0.0.55 pip install llama-stack==0.0.61
``` ```
--- ---
@ -96,7 +96,7 @@ If you're looking for more specific topics, we have a [Zero to Hero Guide](#next
3. **Set the ENV variables by exporting them to the terminal**: 3. **Set the ENV variables by exporting them to the terminal**:
```bash ```bash
export OLLAMA_URL="http://localhost:11434" export OLLAMA_URL="http://localhost:11434"
export LLAMA_STACK_PORT=5051 export LLAMA_STACK_PORT=5001
export INFERENCE_MODEL="meta-llama/Llama-3.2-3B-Instruct" export INFERENCE_MODEL="meta-llama/Llama-3.2-3B-Instruct"
export SAFETY_MODEL="meta-llama/Llama-Guard-3-1B" export SAFETY_MODEL="meta-llama/Llama-Guard-3-1B"
``` ```
@ -104,34 +104,29 @@ If you're looking for more specific topics, we have a [Zero to Hero Guide](#next
3. **Run the Llama Stack**: 3. **Run the Llama Stack**:
Run the stack with command shared by the API from earlier: Run the stack with command shared by the API from earlier:
```bash ```bash
llama stack run ollama \ llama stack run ollama \
--port $LLAMA_STACK_PORT \ --port $LLAMA_STACK_PORT \
--env INFERENCE_MODEL=$INFERENCE_MODEL \ --env INFERENCE_MODEL=$INFERENCE_MODEL \
--env SAFETY_MODEL=$SAFETY_MODEL \ --env SAFETY_MODEL=$SAFETY_MODEL \
--env OLLAMA_URL=$OLLAMA_URL --env OLLAMA_URL=$OLLAMA_URL
``` ```
Note: Every time you run a new model with `ollama run`, you will need to restart the llama stack. Otherwise it won't see the new model. Note: Every time you run a new model with `ollama run`, you will need to restart the llama stack. Otherwise it won't see the new model.
The server will start and listen on `http://localhost:5051`. The server will start and listen on `http://localhost:5001`.
--- ---
## Test with `llama-stack-client` CLI ## Test with `llama-stack-client` CLI
After setting up the server, open a new terminal window and install the llama-stack-client package. After setting up the server, open a new terminal window and configure the llama-stack-client.
1. Install the llama-stack-client package 1. Configure the CLI to point to the llama-stack server.
```bash ```bash
conda activate ollama llama-stack-client configure --endpoint http://localhost:5001
pip install llama-stack-client
```
2. Configure the CLI to point to the llama-stack server.
```bash
llama-stack-client configure --endpoint http://localhost:5051
``` ```
**Expected Output:** **Expected Output:**
```bash ```bash
Done! You can now use the Llama Stack Client CLI with endpoint http://localhost:5051 Done! You can now use the Llama Stack Client CLI with endpoint http://localhost:5001
``` ```
3. Test the CLI by running inference: 2. Test the CLI by running inference:
```bash ```bash
llama-stack-client inference chat-completion --message "Write me a 2-sentence poem about the moon" llama-stack-client inference chat-completion --message "Write me a 2-sentence poem about the moon"
``` ```
@ -153,16 +148,18 @@ After setting up the server, open a new terminal window and install the llama-st
After setting up the server, open a new terminal window and verify it's working by sending a `POST` request using `curl`: After setting up the server, open a new terminal window and verify it's working by sending a `POST` request using `curl`:
```bash ```bash
curl http://localhost:$LLAMA_STACK_PORT/inference/chat_completion \ curl http://localhost:$LLAMA_STACK_PORT/alpha/inference/chat-completion \
-H "Content-Type: application/json" \ -H "Content-Type: application/json" \
-d '{ -d @- <<EOF
"model": "Llama3.2-3B-Instruct", {
"model_id": "$INFERENCE_MODEL",
"messages": [ "messages": [
{"role": "system", "content": "You are a helpful assistant."}, {"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Write me a 2-sentence poem about the moon"} {"role": "user", "content": "Write me a 2-sentence poem about the moon"}
], ],
"sampling_params": {"temperature": 0.7, "seed": 42, "max_tokens": 512} "sampling_params": {"temperature": 0.7, "seed": 42, "max_tokens": 512}
}' }
EOF
``` ```
You can check the available models with the command `llama-stack-client models list`. You can check the available models with the command `llama-stack-client models list`.
@ -186,16 +183,12 @@ You can check the available models with the command `llama-stack-client models l
You can also interact with the Llama Stack server using a simple Python script. Below is an example: You can also interact with the Llama Stack server using a simple Python script. Below is an example:
### 1. Activate Conda Environment and Install Required Python Packages ### 1. Activate Conda Environment
The `llama-stack-client` library offers a robust and efficient python methods for interacting with the Llama Stack server.
```bash ```bash
conda activate ollama conda activate ollama
pip install llama-stack-client
``` ```
Note: the client library is installed by default when you install the server library
### 2. Create Python Script (`test_llama_stack.py`) ### 2. Create Python Script (`test_llama_stack.py`)
```bash ```bash
touch test_llama_stack.py touch test_llama_stack.py
@ -206,19 +199,28 @@ touch test_llama_stack.py
In `test_llama_stack.py`, write the following code: In `test_llama_stack.py`, write the following code:
```python ```python
from llama_stack_client import LlamaStackClient import os
from llama_stack_client import LlamaStackClient
# Initialize the client # Get the model ID from the environment variable
client = LlamaStackClient(base_url="http://localhost:5051") INFERENCE_MODEL = os.environ.get("INFERENCE_MODEL")
" # Create a chat completion request # Check if the environment variable is set
if INFERENCE_MODEL is None:
raise ValueError("The environment variable 'INFERENCE_MODEL' is not set.")
# Initialize the client
client = LlamaStackClient(base_url="http://localhost:5001")
# Create a chat completion request
response = client.inference.chat_completion( response = client.inference.chat_completion(
messages=[ messages=[
{"role": "system", "content": "You are a friendly assistant."}, {"role": "system", "content": "You are a friendly assistant."},
{"role": "user", "content": "Write a two-sentence poem about llama."} {"role": "user", "content": "Write a two-sentence poem about llama."}
], ],
model_id=MODEL_NAME, model_id=INFERENCE_MODEL,
) )
# Print the response # Print the response
print(response.completion_message.content) print(response.completion_message.content)
``` ```

View file

@ -6,9 +6,10 @@
from typing import Any, Dict, List, Literal, Optional, Protocol, Union from typing import Any, Dict, List, Literal, Optional, Protocol, Union
from llama_models.llama3.api.datatypes import BaseModel, Field
from llama_models.schema_utils import json_schema_type, webmethod from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field
from typing_extensions import Annotated from typing_extensions import Annotated
from llama_stack.apis.agents import AgentConfig from llama_stack.apis.agents import AgentConfig

View file

@ -47,7 +47,7 @@ class Scoring(Protocol):
async def score_batch( async def score_batch(
self, self,
dataset_id: str, dataset_id: str,
scoring_functions: Dict[str, Optional[ScoringFnParams]] = None, scoring_functions: Dict[str, Optional[ScoringFnParams]],
save_results_dataset: bool = False, save_results_dataset: bool = False,
) -> ScoreBatchResponse: ... ) -> ScoreBatchResponse: ...
@ -55,5 +55,5 @@ class Scoring(Protocol):
async def score( async def score(
self, self,
input_rows: List[Dict[str, Any]], input_rows: List[Dict[str, Any]],
scoring_functions: Dict[str, Optional[ScoringFnParams]] = None, scoring_functions: Dict[str, Optional[ScoringFnParams]],
) -> ScoreResponse: ... ) -> ScoreResponse: ...

View file

@ -126,7 +126,7 @@ ENTRYPOINT ["python", "-m", "llama_stack.distribution.server.server", "--templat
EOF EOF
printf "Dockerfile created successfully in $TEMP_DIR/Dockerfile" printf "Dockerfile created successfully in $TEMP_DIR/Dockerfile\n\n"
cat $TEMP_DIR/Dockerfile cat $TEMP_DIR/Dockerfile
printf "\n" printf "\n"

View file

@ -39,6 +39,7 @@ from llama_stack.distribution.server.endpoints import get_all_api_endpoints
from llama_stack.distribution.stack import ( from llama_stack.distribution.stack import (
construct_stack, construct_stack,
get_stack_run_config_from_template, get_stack_run_config_from_template,
redact_sensitive_fields,
replace_env_vars, replace_env_vars,
) )
@ -273,7 +274,10 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
console = Console() console = Console()
console.print(f"Using config [blue]{self.config_path_or_template_name}[/blue]:") console.print(f"Using config [blue]{self.config_path_or_template_name}[/blue]:")
console.print(yaml.dump(self.config.model_dump(), indent=2))
# Redact sensitive information before printing
safe_config = redact_sensitive_fields(self.config.model_dump())
console.print(yaml.dump(safe_config, indent=2))
endpoints = get_all_api_endpoints() endpoints = get_all_api_endpoints()
endpoint_impls = {} endpoint_impls = {}

View file

@ -35,6 +35,7 @@ from llama_stack.distribution.request_headers import set_request_provider_data
from llama_stack.distribution.resolver import InvalidProviderError from llama_stack.distribution.resolver import InvalidProviderError
from llama_stack.distribution.stack import ( from llama_stack.distribution.stack import (
construct_stack, construct_stack,
redact_sensitive_fields,
replace_env_vars, replace_env_vars,
validate_env_pair, validate_env_pair,
) )
@ -280,7 +281,8 @@ def main():
config = StackRunConfig(**config) config = StackRunConfig(**config)
print("Run configuration:") print("Run configuration:")
print(yaml.dump(config.model_dump(), indent=2)) safe_config = redact_sensitive_fields(config.model_dump())
print(yaml.dump(safe_config, indent=2))
app = FastAPI(lifespan=lifespan) app = FastAPI(lifespan=lifespan)
app.add_middleware(TracingMiddleware) app.add_middleware(TracingMiddleware)

View file

@ -112,6 +112,26 @@ class EnvVarError(Exception):
) )
def redact_sensitive_fields(data: Dict[str, Any]) -> Dict[str, Any]:
"""Redact sensitive information from config before printing."""
sensitive_patterns = ["api_key", "api_token", "password", "secret"]
def _redact_dict(d: Dict[str, Any]) -> Dict[str, Any]:
result = {}
for k, v in d.items():
if isinstance(v, dict):
result[k] = _redact_dict(v)
elif isinstance(v, list):
result[k] = [_redact_dict(i) if isinstance(i, dict) else i for i in v]
elif any(pattern in k.lower() for pattern in sensitive_patterns):
result[k] = "********"
else:
result[k] = v
return result
return _redact_dict(data)
def replace_env_vars(config: Any, path: str = "") -> Any: def replace_env_vars(config: Any, path: str = "") -> Any:
if isinstance(config, dict): if isinstance(config, dict):
result = {} result = {}
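For reference, a minimal usage sketch of the redaction helper added above. The import path follows the `llama_stack.distribution.stack` imports elsewhere in this diff; the sample config values are invented.

```python
# Hedged usage sketch of redact_sensitive_fields (sample values are made up).
from llama_stack.distribution.stack import redact_sensitive_fields

config = {
    "inference": [
        {
            "provider_id": "together",
            "provider_type": "remote::together",
            "config": {"url": "https://api.together.xyz/v1", "api_key": "tok-123"},
        }
    ],
    "telemetry": {"provider_type": "inline::meta-reference"},
}

safe = redact_sensitive_fields(config)

# Keys containing api_key/api_token/password/secret are masked; nested dicts
# inside lists are handled, and everything else passes through unchanged.
assert safe["inference"][0]["config"]["api_key"] == "********"
assert safe["inference"][0]["config"]["url"] == "https://api.together.xyz/v1"
```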

View file

@ -129,7 +129,7 @@ def application_evaluation_page():
# Display current row results using separate containers # Display current row results using separate containers
progress_text_container.write( progress_text_container.write(
f"Expand to see current processed result ({i+1}/{len(rows)})" f"Expand to see current processed result ({i + 1} / {len(rows)})"
) )
results_container.json( results_container.json(
score_res.to_json(), score_res.to_json(),

View file

@ -232,7 +232,7 @@ def run_evaluation_3():
output_res[scoring_fn].append(eval_res.scores[scoring_fn].score_rows[0]) output_res[scoring_fn].append(eval_res.scores[scoring_fn].score_rows[0])
progress_text_container.write( progress_text_container.write(
f"Expand to see current processed result ({i+1}/{len(rows)})" f"Expand to see current processed result ({i + 1} / {len(rows)})"
) )
results_container.json(eval_res, expanded=2) results_container.json(eval_res, expanded=2)

View file

@ -584,7 +584,7 @@ class ChatAgent(ShieldRunnerMixin):
tool_call = message.tool_calls[0] tool_call = message.tool_calls[0]
name = tool_call.tool_name name = tool_call.tool_name
if not isinstance(name, BuiltinTool): if not isinstance(name, BuiltinTool) or name not in enabled_tools:
yield message yield message
return return

View file

@ -3,23 +3,24 @@
# #
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from enum import Enum
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from tqdm import tqdm from tqdm import tqdm
from llama_stack.apis.agents import Agents from llama_stack.apis.agents import Agents, StepType
from llama_stack.apis.common.type_system import (
ChatCompletionInputType,
CompletionInputType,
StringType,
)
from llama_stack.apis.datasetio import DatasetIO from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets from llama_stack.apis.datasets import Datasets
from llama_stack.apis.eval_tasks import EvalTask from llama_stack.apis.eval_tasks import EvalTask
from llama_stack.apis.inference import Inference, UserMessage from llama_stack.apis.inference import Inference, UserMessage
from llama_stack.apis.scoring import Scoring from llama_stack.apis.scoring import Scoring
from llama_stack.distribution.datatypes import Api
from llama_stack.providers.datatypes import EvalTasksProtocolPrivate from llama_stack.providers.datatypes import EvalTasksProtocolPrivate
from llama_stack.providers.utils.common.data_schema_validator import (
ColumnName,
DataSchemaValidatorMixin,
get_valid_schemas,
)
from llama_stack.providers.utils.kvstore import kvstore_impl from llama_stack.providers.utils.kvstore import kvstore_impl
from .....apis.common.job_types import Job from .....apis.common.job_types import Job
@ -30,15 +31,7 @@ from .config import MetaReferenceEvalConfig
EVAL_TASKS_PREFIX = "eval_tasks:" EVAL_TASKS_PREFIX = "eval_tasks:"
class ColumnName(Enum): class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate, DataSchemaValidatorMixin):
input_query = "input_query"
expected_answer = "expected_answer"
chat_completion_input = "chat_completion_input"
completion_input = "completion_input"
generated_answer = "generated_answer"
class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
def __init__( def __init__(
self, self,
config: MetaReferenceEvalConfig, config: MetaReferenceEvalConfig,
@ -82,29 +75,6 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
) )
self.eval_tasks[task_def.identifier] = task_def self.eval_tasks[task_def.identifier] = task_def
async def validate_eval_input_dataset_schema(self, dataset_id: str) -> None:
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
raise ValueError(f"Dataset {dataset_id} does not have a schema defined.")
expected_schemas = [
{
ColumnName.input_query.value: StringType(),
ColumnName.expected_answer.value: StringType(),
ColumnName.chat_completion_input.value: ChatCompletionInputType(),
},
{
ColumnName.input_query.value: StringType(),
ColumnName.expected_answer.value: StringType(),
ColumnName.completion_input.value: CompletionInputType(),
},
]
if dataset_def.dataset_schema not in expected_schemas:
raise ValueError(
f"Dataset {dataset_id} does not have a correct input schema in {expected_schemas}"
)
async def run_eval( async def run_eval(
self, self,
task_id: str, task_id: str,
@ -114,8 +84,10 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
dataset_id = task_def.dataset_id dataset_id = task_def.dataset_id
candidate = task_config.eval_candidate candidate = task_config.eval_candidate
scoring_functions = task_def.scoring_functions scoring_functions = task_def.scoring_functions
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
await self.validate_eval_input_dataset_schema(dataset_id=dataset_id) self.validate_dataset_schema(
dataset_def.dataset_schema, get_valid_schemas(Api.eval.value)
)
all_rows = await self.datasetio_api.get_rows_paginated( all_rows = await self.datasetio_api.get_rows_paginated(
dataset_id=dataset_id, dataset_id=dataset_id,
rows_in_page=( rows_in_page=(
@ -167,11 +139,21 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
) )
] ]
final_event = turn_response[-1].event.payload final_event = turn_response[-1].event.payload
generations.append(
{ # check if there's a memory retrieval step and extract the context
ColumnName.generated_answer.value: final_event.turn.output_message.content memory_rag_context = None
} for step in final_event.turn.steps:
if step.step_type == StepType.memory_retrieval.value:
memory_rag_context = " ".join(x.text for x in step.inserted_context)
agent_generation = {}
agent_generation[ColumnName.generated_answer.value] = (
final_event.turn.output_message.content
) )
if memory_rag_context:
agent_generation[ColumnName.context.value] = memory_rag_context
generations.append(agent_generation)
return generations return generations

View file

@ -18,6 +18,7 @@ from llama_models.datatypes import Model
from llama_models.sku_list import resolve_model from llama_models.sku_list import resolve_model
from llama_stack.apis.common.type_system import ParamType, StringType from llama_stack.apis.common.type_system import ParamType, StringType
from llama_stack.apis.datasets import Datasets from llama_stack.apis.datasets import Datasets
from pydantic import BaseModel from pydantic import BaseModel
from torchtune.models.llama3 import llama3_tokenizer from torchtune.models.llama3 import llama3_tokenizer

View file

@ -7,6 +7,7 @@
import logging import logging
import os import os
import time import time
from datetime import datetime
from functools import partial from functools import partial
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple from typing import Any, Dict, List, Optional, Tuple

View file

@ -14,8 +14,13 @@ from llama_stack.apis.scoring import (
ScoringResult, ScoringResult,
) )
from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams
from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
from llama_stack.distribution.datatypes import Api
from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
from llama_stack.providers.utils.common.data_schema_validator import (
DataSchemaValidatorMixin,
get_valid_schemas,
)
from .config import BasicScoringConfig from .config import BasicScoringConfig
from .scoring_fn.equality_scoring_fn import EqualityScoringFn from .scoring_fn.equality_scoring_fn import EqualityScoringFn
from .scoring_fn.regex_parser_scoring_fn import RegexParserScoringFn from .scoring_fn.regex_parser_scoring_fn import RegexParserScoringFn
@ -24,7 +29,9 @@ from .scoring_fn.subset_of_scoring_fn import SubsetOfScoringFn
FIXED_FNS = [EqualityScoringFn, SubsetOfScoringFn, RegexParserScoringFn] FIXED_FNS = [EqualityScoringFn, SubsetOfScoringFn, RegexParserScoringFn]
class BasicScoringImpl(Scoring, ScoringFunctionsProtocolPrivate): class BasicScoringImpl(
Scoring, ScoringFunctionsProtocolPrivate, DataSchemaValidatorMixin
):
def __init__( def __init__(
self, self,
config: BasicScoringConfig, config: BasicScoringConfig,
@ -61,30 +68,17 @@ class BasicScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
async def register_scoring_function(self, function_def: ScoringFn) -> None: async def register_scoring_function(self, function_def: ScoringFn) -> None:
raise NotImplementedError("Register scoring function not implemented yet") raise NotImplementedError("Register scoring function not implemented yet")
async def validate_scoring_input_dataset_schema(self, dataset_id: str) -> None:
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
raise ValueError(
f"Dataset {dataset_id} does not have a schema defined. Please define a schema for the dataset."
)
for required_column in ["generated_answer", "expected_answer", "input_query"]:
if required_column not in dataset_def.dataset_schema:
raise ValueError(
f"Dataset {dataset_id} does not have a '{required_column}' column."
)
if dataset_def.dataset_schema[required_column].type != "string":
raise ValueError(
f"Dataset {dataset_id} does not have a '{required_column}' column of type 'string'."
)
async def score_batch( async def score_batch(
self, self,
dataset_id: str, dataset_id: str,
scoring_functions: Dict[str, Optional[ScoringFnParams]] = None, scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
save_results_dataset: bool = False, save_results_dataset: bool = False,
) -> ScoreBatchResponse: ) -> ScoreBatchResponse:
await self.validate_scoring_input_dataset_schema(dataset_id=dataset_id) dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
self.validate_dataset_schema(
dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value)
)
all_rows = await self.datasetio_api.get_rows_paginated( all_rows = await self.datasetio_api.get_rows_paginated(
dataset_id=dataset_id, dataset_id=dataset_id,
rows_in_page=-1, rows_in_page=-1,

View file

@ -9,12 +9,12 @@ from typing import Any, Dict, Optional
from llama_stack.apis.scoring import ScoringResultRow from llama_stack.apis.scoring import ScoringResultRow
from llama_stack.apis.scoring_functions import ScoringFnParams from llama_stack.apis.scoring_functions import ScoringFnParams
from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
from .fn_defs.equality import equality from .fn_defs.equality import equality
class EqualityScoringFn(BaseScoringFn): class EqualityScoringFn(RegisteredBaseScoringFn):
""" """
A scoring_fn that assigns a score of 1.0 if the input string matches the target string, and 0.0 otherwise. A scoring_fn that assigns a score of 1.0 if the input string matches the target string, and 0.0 otherwise.
""" """

View file

@ -9,14 +9,14 @@ from typing import Any, Dict, Optional
from llama_stack.apis.scoring import ScoringResultRow from llama_stack.apis.scoring import ScoringResultRow
from llama_stack.apis.scoring_functions import ScoringFnParams, ScoringFnParamsType from llama_stack.apis.scoring_functions import ScoringFnParams, ScoringFnParamsType
from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
from .fn_defs.regex_parser_multiple_choice_answer import ( from .fn_defs.regex_parser_multiple_choice_answer import (
regex_parser_multiple_choice_answer, regex_parser_multiple_choice_answer,
) )
class RegexParserScoringFn(BaseScoringFn): class RegexParserScoringFn(RegisteredBaseScoringFn):
""" """
A scoring_fn that parses answer from generated response according to context and check match with expected_answer. A scoring_fn that parses answer from generated response according to context and check match with expected_answer.
""" """

View file

@ -8,12 +8,12 @@ from typing import Any, Dict, Optional
from llama_stack.apis.scoring import ScoringResultRow from llama_stack.apis.scoring import ScoringResultRow
from llama_stack.apis.scoring_functions import ScoringFnParams from llama_stack.apis.scoring_functions import ScoringFnParams
from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
from .fn_defs.subset_of import subset_of from .fn_defs.subset_of import subset_of
class SubsetOfScoringFn(BaseScoringFn): class SubsetOfScoringFn(RegisteredBaseScoringFn):
""" """
A scoring_fn that assigns a score of 1.0 if the expected string is included in the generated string, and 0.0 otherwise. A scoring_fn that assigns a score of 1.0 if the expected string is included in the generated string, and 0.0 otherwise.
""" """

View file

@ -7,7 +7,17 @@ import os
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from autoevals.llm import Factuality from autoevals.llm import Factuality
from autoevals.ragas import AnswerCorrectness from autoevals.ragas import (
AnswerCorrectness,
AnswerRelevancy,
AnswerSimilarity,
ContextEntityRecall,
ContextPrecision,
ContextRecall,
ContextRelevancy,
Faithfulness,
)
from pydantic import BaseModel
from llama_stack.apis.datasetio import DatasetIO from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets from llama_stack.apis.datasets import Datasets
@ -18,20 +28,90 @@ from llama_stack.apis.scoring import (
ScoringResult, ScoringResult,
ScoringResultRow, ScoringResultRow,
) )
from llama_stack.apis.scoring_functions import AggregationFunctionType, ScoringFn from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams
from llama_stack.distribution.datatypes import Api
from llama_stack.distribution.request_headers import NeedsRequestProviderData from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
from llama_stack.providers.utils.common.data_schema_validator import (
DataSchemaValidatorMixin,
get_valid_schemas,
)
from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_average from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_metrics
from .config import BraintrustScoringConfig from .config import BraintrustScoringConfig
from .scoring_fn.fn_defs.answer_correctness import answer_correctness_fn_def from .scoring_fn.fn_defs.answer_correctness import answer_correctness_fn_def
from .scoring_fn.fn_defs.answer_relevancy import answer_relevancy_fn_def
from .scoring_fn.fn_defs.answer_similarity import answer_similarity_fn_def
from .scoring_fn.fn_defs.context_entity_recall import context_entity_recall_fn_def
from .scoring_fn.fn_defs.context_precision import context_precision_fn_def
from .scoring_fn.fn_defs.context_recall import context_recall_fn_def
from .scoring_fn.fn_defs.context_relevancy import context_relevancy_fn_def
from .scoring_fn.fn_defs.factuality import factuality_fn_def from .scoring_fn.fn_defs.factuality import factuality_fn_def
from .scoring_fn.fn_defs.faithfulness import faithfulness_fn_def
class BraintrustScoringFnEntry(BaseModel):
identifier: str
evaluator: Any
fn_def: ScoringFn
SUPPORTED_BRAINTRUST_SCORING_FN_ENTRY = [
BraintrustScoringFnEntry(
identifier="braintrust::factuality",
evaluator=Factuality(),
fn_def=factuality_fn_def,
),
BraintrustScoringFnEntry(
identifier="braintrust::answer-correctness",
evaluator=AnswerCorrectness(),
fn_def=answer_correctness_fn_def,
),
BraintrustScoringFnEntry(
identifier="braintrust::answer-relevancy",
evaluator=AnswerRelevancy(),
fn_def=answer_relevancy_fn_def,
),
BraintrustScoringFnEntry(
identifier="braintrust::answer-similarity",
evaluator=AnswerSimilarity(),
fn_def=answer_similarity_fn_def,
),
BraintrustScoringFnEntry(
identifier="braintrust::faithfulness",
evaluator=Faithfulness(),
fn_def=faithfulness_fn_def,
),
BraintrustScoringFnEntry(
identifier="braintrust::context-entity-recall",
evaluator=ContextEntityRecall(),
fn_def=context_entity_recall_fn_def,
),
BraintrustScoringFnEntry(
identifier="braintrust::context-precision",
evaluator=ContextPrecision(),
fn_def=context_precision_fn_def,
),
BraintrustScoringFnEntry(
identifier="braintrust::context-recall",
evaluator=ContextRecall(),
fn_def=context_recall_fn_def,
),
BraintrustScoringFnEntry(
identifier="braintrust::context-relevancy",
evaluator=ContextRelevancy(),
fn_def=context_relevancy_fn_def,
),
]
class BraintrustScoringImpl( class BraintrustScoringImpl(
Scoring, ScoringFunctionsProtocolPrivate, NeedsRequestProviderData Scoring,
ScoringFunctionsProtocolPrivate,
NeedsRequestProviderData,
DataSchemaValidatorMixin,
): ):
def __init__( def __init__(
self, self,
@ -44,12 +124,12 @@ class BraintrustScoringImpl(
self.datasets_api = datasets_api self.datasets_api = datasets_api
self.braintrust_evaluators = { self.braintrust_evaluators = {
"braintrust::factuality": Factuality(), entry.identifier: entry.evaluator
"braintrust::answer-correctness": AnswerCorrectness(), for entry in SUPPORTED_BRAINTRUST_SCORING_FN_ENTRY
} }
self.supported_fn_defs_registry = { self.supported_fn_defs_registry = {
factuality_fn_def.identifier: factuality_fn_def, entry.identifier: entry.fn_def
answer_correctness_fn_def.identifier: answer_correctness_fn_def, for entry in SUPPORTED_BRAINTRUST_SCORING_FN_ENTRY
} }
async def initialize(self) -> None: ... async def initialize(self) -> None: ...
@ -70,23 +150,6 @@ class BraintrustScoringImpl(
"Registering scoring function not allowed for braintrust provider" "Registering scoring function not allowed for braintrust provider"
) )
async def validate_scoring_input_dataset_schema(self, dataset_id: str) -> None:
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
raise ValueError(
f"Dataset {dataset_id} does not have a schema defined. Please define a schema for the dataset."
)
for required_column in ["generated_answer", "expected_answer", "input_query"]:
if required_column not in dataset_def.dataset_schema:
raise ValueError(
f"Dataset {dataset_id} does not have a '{required_column}' column."
)
if dataset_def.dataset_schema[required_column].type != "string":
raise ValueError(
f"Dataset {dataset_id} does not have a '{required_column}' column of type 'string'."
)
async def set_api_key(self) -> None: async def set_api_key(self) -> None:
# api key is in the request headers # api key is in the request headers
if not self.config.openai_api_key: if not self.config.openai_api_key:
@ -102,11 +165,16 @@ class BraintrustScoringImpl(
async def score_batch( async def score_batch(
self, self,
dataset_id: str, dataset_id: str,
scoring_functions: List[str], scoring_functions: Dict[str, Optional[ScoringFnParams]],
save_results_dataset: bool = False, save_results_dataset: bool = False,
) -> ScoreBatchResponse: ) -> ScoreBatchResponse:
await self.set_api_key() await self.set_api_key()
await self.validate_scoring_input_dataset_schema(dataset_id=dataset_id)
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
self.validate_dataset_schema(
dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value)
)
all_rows = await self.datasetio_api.get_rows_paginated( all_rows = await self.datasetio_api.get_rows_paginated(
dataset_id=dataset_id, dataset_id=dataset_id,
rows_in_page=-1, rows_in_page=-1,
@ -126,6 +194,7 @@ class BraintrustScoringImpl(
async def score_row( async def score_row(
self, input_row: Dict[str, Any], scoring_fn_identifier: Optional[str] = None self, input_row: Dict[str, Any], scoring_fn_identifier: Optional[str] = None
) -> ScoringResultRow: ) -> ScoringResultRow:
self.validate_row_schema(input_row, get_valid_schemas(Api.scoring.value))
await self.set_api_key() await self.set_api_key()
assert scoring_fn_identifier is not None, "scoring_fn_identifier cannot be None" assert scoring_fn_identifier is not None, "scoring_fn_identifier cannot be None"
expected_answer = input_row["expected_answer"] expected_answer = input_row["expected_answer"]
@ -133,12 +202,19 @@ class BraintrustScoringImpl(
input_query = input_row["input_query"] input_query = input_row["input_query"]
evaluator = self.braintrust_evaluators[scoring_fn_identifier] evaluator = self.braintrust_evaluators[scoring_fn_identifier]
result = evaluator(generated_answer, expected_answer, input=input_query) result = evaluator(
generated_answer,
expected_answer,
input=input_query,
context=input_row["context"] if "context" in input_row else None,
)
score = result.score score = result.score
return {"score": score, "metadata": result.metadata} return {"score": score, "metadata": result.metadata}
async def score( async def score(
self, input_rows: List[Dict[str, Any]], scoring_functions: List[str] self,
input_rows: List[Dict[str, Any]],
scoring_functions: Dict[str, Optional[ScoringFnParams]],
) -> ScoreResponse: ) -> ScoreResponse:
await self.set_api_key() await self.set_api_key()
res = {} res = {}
@ -150,8 +226,17 @@ class BraintrustScoringImpl(
await self.score_row(input_row, scoring_fn_id) await self.score_row(input_row, scoring_fn_id)
for input_row in input_rows for input_row in input_rows
] ]
aggregation_functions = [AggregationFunctionType.average] aggregation_functions = self.supported_fn_defs_registry[
agg_results = aggregate_average(score_results) scoring_fn_id
].params.aggregation_functions
# override scoring_fn params if provided
if scoring_functions[scoring_fn_id] is not None:
override_params = scoring_functions[scoring_fn_id]
if override_params.aggregation_functions:
aggregation_functions = override_params.aggregation_functions
agg_results = aggregate_metrics(score_results, aggregation_functions)
res[scoring_fn_id] = ScoringResult( res[scoring_fn_id] = ScoringResult(
score_rows=score_results, score_rows=score_results,
aggregated_results=agg_results, aggregated_results=agg_results,
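For clarity, an illustrative-only sketch of the default-plus-override behavior introduced above, using plain dataclasses rather than the real `ScoringFnParams` types: the registered scoring-fn definition supplies default aggregation functions, and a non-None entry in the caller's `scoring_functions` dict overrides them.

```python
# Toy model of the aggregation-override logic (names and types are stand-ins).
from dataclasses import dataclass, field
from typing import Dict, List, Optional


@dataclass
class FnParams:
    aggregation_functions: List[str] = field(default_factory=list)


@dataclass
class FnDef:
    params: FnParams


registry: Dict[str, FnDef] = {"braintrust::factuality": FnDef(FnParams(["average"]))}


def resolve_aggregations(
    fn_id: str, scoring_functions: Dict[str, Optional[FnParams]]
) -> List[str]:
    # Default comes from the registered scoring-fn definition.
    aggs = registry[fn_id].params.aggregation_functions
    # A non-None caller entry with aggregation functions overrides it.
    override = scoring_functions.get(fn_id)
    if override is not None and override.aggregation_functions:
        aggs = override.aggregation_functions
    return aggs


print(resolve_aggregations("braintrust::factuality", {"braintrust::factuality": None}))
# -> ['average']
print(resolve_aggregations("braintrust::factuality", {"braintrust::factuality": FnParams(["median"])}))
# -> ['median']
```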

View file

@ -5,14 +5,23 @@
# the root directory of this source tree. # the root directory of this source tree.
from llama_stack.apis.common.type_system import NumberType from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import ScoringFn from llama_stack.apis.scoring_functions import (
AggregationFunctionType,
BasicScoringFnParams,
ScoringFn,
)
answer_correctness_fn_def = ScoringFn( answer_correctness_fn_def = ScoringFn(
identifier="braintrust::answer-correctness", identifier="braintrust::answer-correctness",
description="Scores the correctness of the answer based on the ground truth.. One of Braintrust LLM basd scorer https://github.com/braintrustdata/autoevals/blob/main/py/autoevals/llm.py", description=(
params=None, "Scores the correctness of the answer based on the ground truth. "
"Uses Braintrust LLM-based scorer from autoevals library."
),
provider_id="braintrust", provider_id="braintrust",
provider_resource_id="answer-correctness", provider_resource_id="answer-correctness",
return_type=NumberType(), return_type=NumberType(),
params=BasicScoringFnParams(
aggregation_functions=[AggregationFunctionType.average]
),
) )

View file

@ -0,0 +1,26 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import (
AggregationFunctionType,
BasicScoringFnParams,
ScoringFn,
)
answer_relevancy_fn_def = ScoringFn(
identifier="braintrust::answer-relevancy",
description=(
"Test output relevancy against the input query using Braintrust LLM scorer. "
"See: github.com/braintrustdata/autoevals"
),
provider_id="braintrust",
provider_resource_id="answer-relevancy",
return_type=NumberType(),
params=BasicScoringFnParams(
aggregation_functions=[AggregationFunctionType.average]
),
)

View file

@ -0,0 +1,26 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import (
AggregationFunctionType,
BasicScoringFnParams,
ScoringFn,
)
answer_similarity_fn_def = ScoringFn(
identifier="braintrust::answer-similarity",
description=(
"Test output similarity against expected value using Braintrust LLM scorer. "
"See: github.com/braintrustdata/autoevals"
),
provider_id="braintrust",
provider_resource_id="answer-similarity",
return_type=NumberType(),
params=BasicScoringFnParams(
aggregation_functions=[AggregationFunctionType.average]
),
)

View file

@ -0,0 +1,26 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import (
AggregationFunctionType,
BasicScoringFnParams,
ScoringFn,
)
context_entity_recall_fn_def = ScoringFn(
identifier="braintrust::context-entity-recall",
description=(
"Evaluates how well the context captures the named entities present in the "
"reference answer. See: github.com/braintrustdata/autoevals"
),
provider_id="braintrust",
provider_resource_id="context-entity-recall",
return_type=NumberType(),
params=BasicScoringFnParams(
aggregation_functions=[AggregationFunctionType.average]
),
)

View file

@ -0,0 +1,26 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import (
AggregationFunctionType,
BasicScoringFnParams,
ScoringFn,
)
context_precision_fn_def = ScoringFn(
identifier="braintrust::context-precision",
description=(
"Measures how much of the provided context is actually relevant to answering the "
"question. See: github.com/braintrustdata/autoevals"
),
provider_id="braintrust",
provider_resource_id="context-precision",
return_type=NumberType(),
params=BasicScoringFnParams(
aggregation_functions=[AggregationFunctionType.average]
),
)

View file

@ -0,0 +1,26 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import (
AggregationFunctionType,
BasicScoringFnParams,
ScoringFn,
)
context_recall_fn_def = ScoringFn(
identifier="braintrust::context-recall",
description=(
"Evaluates how well the context covers the information needed to answer the "
"question. See: github.com/braintrustdata/autoevals"
),
provider_id="braintrust",
provider_resource_id="context-recall",
return_type=NumberType(),
params=BasicScoringFnParams(
aggregation_functions=[AggregationFunctionType.average]
),
)

View file

@ -0,0 +1,26 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import (
AggregationFunctionType,
BasicScoringFnParams,
ScoringFn,
)
context_relevancy_fn_def = ScoringFn(
identifier="braintrust::context-relevancy",
description=(
"Assesses how relevant the provided context is to the given question. "
"See: github.com/braintrustdata/autoevals"
),
provider_id="braintrust",
provider_resource_id="context-relevancy",
return_type=NumberType(),
params=BasicScoringFnParams(
aggregation_functions=[AggregationFunctionType.average]
),
)

View file

@ -5,14 +5,23 @@
# the root directory of this source tree. # the root directory of this source tree.
from llama_stack.apis.common.type_system import NumberType from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import ScoringFn from llama_stack.apis.scoring_functions import (
AggregationFunctionType,
BasicScoringFnParams,
ScoringFn,
)
factuality_fn_def = ScoringFn( factuality_fn_def = ScoringFn(
identifier="braintrust::factuality", identifier="braintrust::factuality",
description="Test whether an output is factual, compared to an original (`expected`) value. One of Braintrust LLM basd scorer https://github.com/braintrustdata/autoevals/blob/main/py/autoevals/llm.py", description=(
params=None, "Test output factuality against expected value using Braintrust LLM scorer. "
"See: github.com/braintrustdata/autoevals"
),
provider_id="braintrust", provider_id="braintrust",
provider_resource_id="factuality", provider_resource_id="factuality",
return_type=NumberType(), return_type=NumberType(),
params=BasicScoringFnParams(
aggregation_functions=[AggregationFunctionType.average]
),
) )

View file

@ -0,0 +1,26 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import (
AggregationFunctionType,
BasicScoringFnParams,
ScoringFn,
)
faithfulness_fn_def = ScoringFn(
identifier="braintrust::faithfulness",
description=(
"Test output faithfulness to the input query using Braintrust LLM scorer. "
"See: github.com/braintrustdata/autoevals"
),
provider_id="braintrust",
provider_resource_id="faithfulness",
return_type=NumberType(),
params=BasicScoringFnParams(
aggregation_functions=[AggregationFunctionType.average]
),
)

View file

@ -16,7 +16,12 @@ from llama_stack.apis.scoring import (
ScoringResult, ScoringResult,
) )
from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams
from llama_stack.distribution.datatypes import Api
from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
from llama_stack.providers.utils.common.data_schema_validator import (
DataSchemaValidatorMixin,
get_valid_schemas,
)
from .config import LlmAsJudgeScoringConfig from .config import LlmAsJudgeScoringConfig
from .scoring_fn.llm_as_judge_scoring_fn import LlmAsJudgeScoringFn from .scoring_fn.llm_as_judge_scoring_fn import LlmAsJudgeScoringFn
@ -25,7 +30,9 @@ from .scoring_fn.llm_as_judge_scoring_fn import LlmAsJudgeScoringFn
LLM_JUDGE_FNS = [LlmAsJudgeScoringFn] LLM_JUDGE_FNS = [LlmAsJudgeScoringFn]
class LlmAsJudgeScoringImpl(Scoring, ScoringFunctionsProtocolPrivate): class LlmAsJudgeScoringImpl(
Scoring, ScoringFunctionsProtocolPrivate, DataSchemaValidatorMixin
):
def __init__( def __init__(
self, self,
config: LlmAsJudgeScoringConfig, config: LlmAsJudgeScoringConfig,
@ -65,30 +72,17 @@ class LlmAsJudgeScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
async def register_scoring_function(self, function_def: ScoringFn) -> None: async def register_scoring_function(self, function_def: ScoringFn) -> None:
raise NotImplementedError("Register scoring function not implemented yet") raise NotImplementedError("Register scoring function not implemented yet")
async def validate_scoring_input_dataset_schema(self, dataset_id: str) -> None:
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
raise ValueError(
f"Dataset {dataset_id} does not have a schema defined. Please define a schema for the dataset."
)
for required_column in ["generated_answer", "expected_answer", "input_query"]:
if required_column not in dataset_def.dataset_schema:
raise ValueError(
f"Dataset {dataset_id} does not have a '{required_column}' column."
)
if dataset_def.dataset_schema[required_column].type != "string":
raise ValueError(
f"Dataset {dataset_id} does not have a '{required_column}' column of type 'string'."
)
async def score_batch( async def score_batch(
self, self,
dataset_id: str, dataset_id: str,
scoring_functions: Dict[str, Optional[ScoringFnParams]] = None, scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
save_results_dataset: bool = False, save_results_dataset: bool = False,
) -> ScoreBatchResponse: ) -> ScoreBatchResponse:
await self.validate_scoring_input_dataset_schema(dataset_id=dataset_id) dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
self.validate_dataset_schema(
dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value)
)
all_rows = await self.datasetio_api.get_rows_paginated( all_rows = await self.datasetio_api.get_rows_paginated(
dataset_id=dataset_id, dataset_id=dataset_id,
rows_in_page=-1, rows_in_page=-1,

View file

@ -12,14 +12,14 @@ from llama_stack.apis.inference.inference import Inference
from llama_stack.apis.scoring import ScoringResultRow from llama_stack.apis.scoring import ScoringResultRow
from llama_stack.apis.scoring_functions import ScoringFnParams from llama_stack.apis.scoring_functions import ScoringFnParams
from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
from .fn_defs.llm_as_judge_405b_simpleqa import llm_as_judge_405b_simpleqa from .fn_defs.llm_as_judge_405b_simpleqa import llm_as_judge_405b_simpleqa
from .fn_defs.llm_as_judge_base import llm_as_judge_base from .fn_defs.llm_as_judge_base import llm_as_judge_base
class LlmAsJudgeScoringFn(BaseScoringFn): class LlmAsJudgeScoringFn(RegisteredBaseScoringFn):
""" """
A scoring_fn that assigns A scoring_fn that assigns
""" """

View file

@ -71,7 +71,8 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
self.formatter = ChatFormat(Tokenizer.get_instance()) self.formatter = ChatFormat(Tokenizer.get_instance())
self.client = AsyncCerebras( self.client = AsyncCerebras(
base_url=self.config.base_url, api_key=self.config.api_key base_url=self.config.base_url,
api_key=self.config.api_key.get_secret_value(),
) )
async def initialize(self) -> None: async def initialize(self) -> None:

View file

@ -8,7 +8,7 @@ import os
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
from llama_models.schema_utils import json_schema_type from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field from pydantic import BaseModel, Field, SecretStr
DEFAULT_BASE_URL = "https://api.cerebras.ai" DEFAULT_BASE_URL = "https://api.cerebras.ai"
@ -19,7 +19,7 @@ class CerebrasImplConfig(BaseModel):
default=os.environ.get("CEREBRAS_BASE_URL", DEFAULT_BASE_URL), default=os.environ.get("CEREBRAS_BASE_URL", DEFAULT_BASE_URL),
description="Base URL for the Cerebras API", description="Base URL for the Cerebras API",
) )
api_key: Optional[str] = Field( api_key: Optional[SecretStr] = Field(
default=os.environ.get("CEREBRAS_API_KEY"), default=os.environ.get("CEREBRAS_API_KEY"),
description="Cerebras API Key", description="Cerebras API Key",
) )
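A minimal sketch (not from the repo) of the `SecretStr` pattern this config change adopts here and in the Fireworks, NVIDIA, TGI, and Together providers below: pydantic masks the value whenever the config is printed or dumped, and `get_secret_value()` unwraps it only at the call site that actually needs the raw key.

```python
# Hypothetical stand-in config illustrating pydantic's SecretStr behavior.
from typing import Optional

from pydantic import BaseModel, Field, SecretStr


class ExampleImplConfig(BaseModel):
    base_url: str = Field(default="https://api.cerebras.ai")
    api_key: Optional[SecretStr] = Field(default=None)


cfg = ExampleImplConfig(api_key="tok-123")
print(cfg)                             # api_key=SecretStr('**********') -- masked in logs
print(cfg.api_key.get_secret_value())  # tok-123 -- only where the client is built
```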

View file

@ -7,7 +7,7 @@
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
from llama_models.schema_utils import json_schema_type from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field from pydantic import BaseModel, Field, SecretStr
@json_schema_type @json_schema_type
@ -16,7 +16,7 @@ class FireworksImplConfig(BaseModel):
default="https://api.fireworks.ai/inference/v1", default="https://api.fireworks.ai/inference/v1",
description="The URL for the Fireworks server", description="The URL for the Fireworks server",
) )
api_key: Optional[str] = Field( api_key: Optional[SecretStr] = Field(
default=None, default=None,
description="The Fireworks.ai API Key", description="The Fireworks.ai API Key",
) )

View file

@ -113,7 +113,7 @@ class FireworksInferenceAdapter(
def _get_api_key(self) -> str: def _get_api_key(self) -> str:
if self.config.api_key is not None: if self.config.api_key is not None:
return self.config.api_key return self.config.api_key.get_secret_value()
else: else:
provider_data = self.get_request_provider_data() provider_data = self.get_request_provider_data()
if provider_data is None or not provider_data.fireworks_api_key: if provider_data is None or not provider_data.fireworks_api_key:

View file

@ -8,7 +8,7 @@ import os
from typing import Optional from typing import Optional
from llama_models.schema_utils import json_schema_type from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field from pydantic import BaseModel, Field, SecretStr
@json_schema_type @json_schema_type
@ -40,7 +40,7 @@ class NVIDIAConfig(BaseModel):
), ),
description="A base url for accessing the NVIDIA NIM", description="A base url for accessing the NVIDIA NIM",
) )
api_key: Optional[str] = Field( api_key: Optional[SecretStr] = Field(
default_factory=lambda: os.getenv("NVIDIA_API_KEY"), default_factory=lambda: os.getenv("NVIDIA_API_KEY"),
description="The NVIDIA API key, only needed of using the hosted service", description="The NVIDIA API key, only needed of using the hosted service",
) )

View file

@ -113,7 +113,11 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
# make sure the client lives longer than any async calls # make sure the client lives longer than any async calls
self._client = AsyncOpenAI( self._client = AsyncOpenAI(
base_url=f"{self._config.url}/v1", base_url=f"{self._config.url}/v1",
api_key=self._config.api_key or "NO KEY", api_key=(
self._config.api_key.get_secret_value()
if self._config.api_key
else "NO KEY"
),
timeout=self._config.timeout, timeout=self._config.timeout,
) )

View file

@ -236,6 +236,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
tool_prompt_format=tool_prompt_format, tool_prompt_format=tool_prompt_format,
stream=stream, stream=stream,
logprobs=logprobs, logprobs=logprobs,
response_format=response_format,
) )
if stream: if stream:
return self._stream_chat_completion(request) return self._stream_chat_completion(request)
@ -279,6 +280,14 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
) )
input_dict["raw"] = True input_dict["raw"] = True
if fmt := request.response_format:
if fmt.type == "json_schema":
input_dict["format"] = fmt.json_schema
elif fmt.type == "grammar":
raise NotImplementedError("Grammar response format is not supported")
else:
raise ValueError(f"Unknown response format type: {fmt.type}")
return { return {
"model": request.model, "model": request.model,
**input_dict, **input_dict,
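A standalone sketch of the new request-to-Ollama mapping, written with plain dicts so it runs without the llama-stack request types (the field names mirror the diff; only `json_schema` is forwarded, while `grammar` is explicitly unsupported).

```python
# Illustrative-only version of the response_format branch added above.
from typing import Any, Dict, Optional


def build_format_params(response_format: Optional[Dict[str, Any]]) -> Dict[str, Any]:
    params: Dict[str, Any] = {}
    if response_format:
        if response_format["type"] == "json_schema":
            # Ollama accepts a JSON schema object in its "format" field.
            params["format"] = response_format["json_schema"]
        elif response_format["type"] == "grammar":
            raise NotImplementedError("Grammar response format is not supported")
        else:
            raise ValueError(f"Unknown response format type: {response_format['type']}")
    return params


print(build_format_params({"type": "json_schema", "json_schema": {"type": "object"}}))
# -> {'format': {'type': 'object'}}
```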

View file

@ -7,7 +7,7 @@
from typing import Optional from typing import Optional
from llama_models.schema_utils import json_schema_type from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field from pydantic import BaseModel, Field, SecretStr
@json_schema_type @json_schema_type
@ -15,7 +15,7 @@ class TGIImplConfig(BaseModel):
url: str = Field( url: str = Field(
description="The URL for the TGI serving endpoint", description="The URL for the TGI serving endpoint",
) )
api_token: Optional[str] = Field( api_token: Optional[SecretStr] = Field(
default=None, default=None,
description="A bearer token if your TGI endpoint is protected.", description="A bearer token if your TGI endpoint is protected.",
) )
@ -32,7 +32,7 @@ class InferenceEndpointImplConfig(BaseModel):
endpoint_name: str = Field( endpoint_name: str = Field(
description="The name of the Hugging Face Inference Endpoint in the format of '{namespace}/{endpoint_name}' (e.g. 'my-cool-org/meta-llama-3-1-8b-instruct-rce'). Namespace is optional and will default to the user account if not provided.", description="The name of the Hugging Face Inference Endpoint in the format of '{namespace}/{endpoint_name}' (e.g. 'my-cool-org/meta-llama-3-1-8b-instruct-rce'). Namespace is optional and will default to the user account if not provided.",
) )
api_token: Optional[str] = Field( api_token: Optional[SecretStr] = Field(
default=None, default=None,
description="Your Hugging Face user access token (will default to locally saved token if not provided)", description="Your Hugging Face user access token (will default to locally saved token if not provided)",
) )
@ -55,7 +55,7 @@ class InferenceAPIImplConfig(BaseModel):
huggingface_repo: str = Field( huggingface_repo: str = Field(
description="The model ID of the model on the Hugging Face Hub (e.g. 'meta-llama/Meta-Llama-3.1-70B-Instruct')", description="The model ID of the model on the Hugging Face Hub (e.g. 'meta-llama/Meta-Llama-3.1-70B-Instruct')",
) )
api_token: Optional[str] = Field( api_token: Optional[SecretStr] = Field(
default=None, default=None,
description="Your Hugging Face user access token (will default to locally saved token if not provided)", description="Your Hugging Face user access token (will default to locally saved token if not provided)",
) )

View file

@@ -290,7 +290,9 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
 class TGIAdapter(_HfAdapter):
     async def initialize(self, config: TGIImplConfig) -> None:
         log.info(f"Initializing TGI client with url={config.url}")
-        self.client = AsyncInferenceClient(model=config.url, token=config.api_token)
+        self.client = AsyncInferenceClient(
+            model=config.url, token=config.api_token.get_secret_value()
+        )
         endpoint_info = await self.client.get_endpoint_info()
         self.max_tokens = endpoint_info["max_total_tokens"]
         self.model_id = endpoint_info["model_id"]
@@ -299,7 +301,7 @@ class TGIAdapter(_HfAdapter):
 class InferenceAPIAdapter(_HfAdapter):
     async def initialize(self, config: InferenceAPIImplConfig) -> None:
         self.client = AsyncInferenceClient(
-            model=config.huggingface_repo, token=config.api_token
+            model=config.huggingface_repo, token=config.api_token.get_secret_value()
         )
         endpoint_info = await self.client.get_endpoint_info()
         self.max_tokens = endpoint_info["max_total_tokens"]
@@ -309,7 +311,7 @@ class InferenceAPIAdapter(_HfAdapter):
 class InferenceEndpointAdapter(_HfAdapter):
     async def initialize(self, config: InferenceEndpointImplConfig) -> None:
         # Get the inference endpoint details
-        api = HfApi(token=config.api_token)
+        api = HfApi(token=config.api_token.get_secret_value())
         endpoint = api.get_inference_endpoint(config.endpoint_name)

         # Wait for the endpoint to be ready (if not already)
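
These hunks swap plain `str` API tokens for pydantic's `SecretStr`, so tokens are masked in `repr()` and logs and only exposed through an explicit `get_secret_value()` call. A minimal standalone sketch of the pattern (the class and values below are made up, not taken from the repo):

```python
from typing import Optional

from pydantic import BaseModel, SecretStr


class ExampleConfig(BaseModel):
    url: str
    # Secret values are masked when the model is printed or serialized.
    api_token: Optional[SecretStr] = None


config = ExampleConfig(url="http://localhost:8080", api_token="my-token")
print(config)                               # api_token=SecretStr('**********')
print(config.api_token.get_secret_value())  # "my-token" -- only on explicit access
```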

View file

@@ -7,7 +7,7 @@
 from typing import Any, Dict, Optional

 from llama_models.schema_utils import json_schema_type
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, SecretStr


 @json_schema_type
@@ -16,7 +16,7 @@ class TogetherImplConfig(BaseModel):
         default="https://api.together.xyz/v1",
         description="The URL for the Together AI server",
     )
-    api_key: Optional[str] = Field(
+    api_key: Optional[SecretStr] = Field(
         default=None,
         description="The Together AI API Key",
     )

View file

@@ -130,7 +130,7 @@ class TogetherInferenceAdapter(
     def _get_client(self) -> Together:
         together_api_key = None
         if self.config.api_key is not None:
-            together_api_key = self.config.api_key
+            together_api_key = self.config.api_key.get_secret_value()
         else:
             provider_data = self.get_request_provider_data()
             if provider_data is None or not provider_data.together_api_key:

View file

@@ -38,9 +38,15 @@ def data_url_from_file(file_path: str) -> str:

 async def register_dataset(
-    datasets_impl: Datasets, for_generation=False, dataset_id="test_dataset"
+    datasets_impl: Datasets,
+    for_generation=False,
+    for_rag=False,
+    dataset_id="test_dataset",
 ):
-    test_file = Path(os.path.abspath(__file__)).parent / "test_dataset.csv"
+    if for_rag:
+        test_file = Path(os.path.abspath(__file__)).parent / "test_rag_dataset.csv"
+    else:
+        test_file = Path(os.path.abspath(__file__)).parent / "test_dataset.csv"

     test_url = data_url_from_file(str(test_file))

     if for_generation:
@@ -49,6 +55,13 @@
             "input_query": StringType(),
             "chat_completion_input": ChatCompletionInputType(),
         }
+    elif for_rag:
+        dataset_schema = {
+            "expected_answer": StringType(),
+            "input_query": StringType(),
+            "generated_answer": StringType(),
+            "context": StringType(),
+        }
     else:
         dataset_schema = {
             "expected_answer": StringType(),

View file

@@ -0,0 +1,6 @@
input_query,context,generated_answer,expected_answer
What is the capital of France?,"France is a country in Western Europe with a population of about 67 million people. Its capital city has been a major European cultural center since the 17th century and is known for landmarks like the Eiffel Tower and the Louvre Museum.",London,Paris
Who is the CEO of Meta?,"Meta Platforms, formerly known as Facebook, is one of the world's largest technology companies. Founded by Mark Zuckerberg in 2004, the company has expanded to include platforms like Instagram, WhatsApp, and virtual reality technologies.",Mark Zuckerberg,Mark Zuckerberg
What is the largest planet in our solar system?,"The solar system consists of eight planets orbiting around the Sun. These planets, in order from the Sun, are Mercury, Venus, Earth, Mars, Jupiter, Saturn, Uranus, and Neptune. Gas giants are significantly larger than terrestrial planets.",Jupiter,Jupiter
What is the smallest country in the world?,"Independent city-states and micronations are among the world's smallest sovereign territories. Some notable examples include Monaco, San Marino, and Vatican City, which is an enclave within Rome, Italy.",China,Vatican City
What is the currency of Japan?,"Japan is an island country in East Asia with a rich cultural heritage and one of the world's largest economies. Its financial system has been established since the Meiji period, with its modern currency being introduced in 1871.",Yen,Yen
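
The new `test_rag_dataset.csv` carries exactly the four columns the `for_rag` schema in the test utility above expects. A small, hypothetical sanity check (the file path and the check itself are illustrative assumptions, not part of the change):

```python
import csv

EXPECTED_COLUMNS = {"input_query", "context", "generated_answer", "expected_answer"}

with open("test_rag_dataset.csv", newline="") as f:
    reader = csv.DictReader(f)
    # The header row must match the schema registered when for_rag=True.
    assert set(reader.fieldnames) == EXPECTED_COLUMNS
    rows = list(reader)

print(f"{len(rows)} rows loaded")  # 5 rows in the sample file
```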

View file

@@ -210,6 +210,7 @@ class TestInference:
         provider = inference_impl.routing_table.get_provider_impl(inference_model)
         if provider.__provider_spec__.provider_type not in (
             "inline::meta-reference",
+            "remote::ollama",
             "remote::tgi",
             "remote::together",
             "remote::fireworks",
@@ -272,6 +273,7 @@ class TestInference:
         provider = inference_impl.routing_table.get_provider_impl(inference_model)
         if provider.__provider_spec__.provider_type not in (
             "inline::meta-reference",
+            "remote::ollama",
             "remote::fireworks",
             "remote::tgi",
             "remote::together",

View file

@@ -60,7 +60,7 @@ class TestScoring:
                 f"{provider_id} provider does not support scoring without params"
             )

-        await register_dataset(datasets_impl)
+        await register_dataset(datasets_impl, for_rag=True)
         response = await datasets_impl.list_datasets()
         assert len(response) == 1

@@ -112,7 +112,7 @@ class TestScoring:
             scoring_stack[Api.datasets],
             scoring_stack[Api.models],
         )
-        await register_dataset(datasets_impl)
+        await register_dataset(datasets_impl, for_rag=True)
         response = await datasets_impl.list_datasets()
         assert len(response) == 1

@@ -173,7 +173,7 @@ class TestScoring:
             scoring_stack[Api.datasets],
             scoring_stack[Api.models],
         )
-        await register_dataset(datasets_impl)
+        await register_dataset(datasets_impl, for_rag=True)
         rows = await datasetio_impl.get_rows_paginated(
             dataset_id="test_dataset",
             rows_in_page=3,

View file

@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@@ -0,0 +1,87 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from enum import Enum
from typing import Any, Dict, List

from llama_stack.apis.common.type_system import (
    ChatCompletionInputType,
    CompletionInputType,
    StringType,
)
from llama_stack.distribution.datatypes import Api


class ColumnName(Enum):
    input_query = "input_query"
    expected_answer = "expected_answer"
    chat_completion_input = "chat_completion_input"
    completion_input = "completion_input"
    generated_answer = "generated_answer"
    context = "context"


VALID_SCHEMAS_FOR_SCORING = [
    {
        ColumnName.input_query.value: StringType(),
        ColumnName.expected_answer.value: StringType(),
        ColumnName.generated_answer.value: StringType(),
    },
    {
        ColumnName.input_query.value: StringType(),
        ColumnName.expected_answer.value: StringType(),
        ColumnName.generated_answer.value: StringType(),
        ColumnName.context.value: StringType(),
    },
]

VALID_SCHEMAS_FOR_EVAL = [
    {
        ColumnName.input_query.value: StringType(),
        ColumnName.expected_answer.value: StringType(),
        ColumnName.chat_completion_input.value: ChatCompletionInputType(),
    },
    {
        ColumnName.input_query.value: StringType(),
        ColumnName.expected_answer.value: StringType(),
        ColumnName.completion_input.value: CompletionInputType(),
    },
]


def get_valid_schemas(api_str: str):
    if api_str == Api.scoring.value:
        return VALID_SCHEMAS_FOR_SCORING
    elif api_str == Api.eval.value:
        return VALID_SCHEMAS_FOR_EVAL
    else:
        raise ValueError(f"Invalid API string: {api_str}")


class DataSchemaValidatorMixin:
    def validate_dataset_schema(
        self,
        dataset_schema: Dict[str, Any],
        expected_schemas: List[Dict[str, Any]],
    ):
        if dataset_schema not in expected_schemas:
            raise ValueError(
                f"Dataset {dataset_schema} does not have a correct input schema in {expected_schemas}"
            )

    def validate_row_schema(
        self,
        input_row: Dict[str, Any],
        expected_schemas: List[Dict[str, Any]],
    ):
        for schema in expected_schemas:
            if all(key in input_row for key in schema):
                return

        raise ValueError(
            f"Input row {input_row} does not match any of the expected schemas in {expected_schemas}"
        )
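
A hedged usage sketch of the validators defined above (it assumes the names from this module are importable; the sample row is invented and mirrors the RAG CSV):

```python
# DataSchemaValidatorMixin, VALID_SCHEMAS_FOR_SCORING and VALID_SCHEMAS_FOR_EVAL
# are the definitions from the module above; the import path is omitted here.
row = {
    "input_query": "What is the capital of France?",
    "context": "France is a country in Western Europe ...",
    "generated_answer": "London",
    "expected_answer": "Paris",
}

validator = DataSchemaValidatorMixin()

# Passes: the row contains every key of a scoring schema (including "context").
validator.validate_row_schema(row, VALID_SCHEMAS_FOR_SCORING)

# Raises ValueError: eval schemas require chat_completion_input or completion_input.
try:
    validator.validate_row_schema(row, VALID_SCHEMAS_FOR_EVAL)
except ValueError as e:
    print(e)
```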

View file

@@ -40,7 +40,6 @@ from llama_stack.apis.common.content_types import (
     InterleavedContent,
     InterleavedContentItem,
     TextContentItem,
-    URL,
 )

 from llama_stack.apis.inference import (
@@ -117,27 +116,31 @@ async def interleaved_content_convert_to_raw(
         elif isinstance(c, TextContentItem):
             return RawTextItem(text=c.text)
         elif isinstance(c, ImageContentItem):
-            # load image and return PIL version
-            img = c.data
-            if isinstance(img, URL):
-                if img.uri.startswith("data"):
-                    match = re.match(r"data:image/(\w+);base64,(.+)", img.uri)
+            if c.url:
+                # Load image bytes from URL
+                if c.url.uri.startswith("data"):
+                    match = re.match(r"data:image/(\w+);base64,(.+)", c.url.uri)
                     if not match:
-                        raise ValueError("Invalid data URL format")
+                        raise ValueError(
+                            f"Invalid data URL format, {c.url.uri[:40]}..."
+                        )
                     _, image_data = match.groups()
                     data = base64.b64decode(image_data)
-                elif img.uri.startswith("file://"):
-                    path = img.uri[len("file://") :]
+                elif c.url.uri.startswith("file://"):
+                    path = c.url.uri[len("file://") :]
                     with open(path, "rb") as f:
                         data = f.read()  # type: ignore
-                elif img.uri.startswith("http"):
+                elif c.url.uri.startswith("http"):
                     async with httpx.AsyncClient() as client:
-                        response = await client.get(img.uri)
+                        response = await client.get(c.url.uri)
                         data = response.content
                 else:
                     raise ValueError("Unsupported URL type")
-            else:
+            elif c.data:
                 data = c.data
+            else:
+                raise ValueError("No data or URL provided")
             return RawMediaItem(data=data)
         else:
             raise ValueError(f"Unsupported content type: {type(c)}")
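
The rewritten branch dispatches on `c.url` versus raw `c.data` and parses base64 data URLs with the same regex as before. A self-contained sketch of that data-URL round trip (the sample bytes are arbitrary placeholders, not a real image):

```python
import base64
import re

# Build a data URL the way a client might attach an image.
raw_bytes = b"\x89PNG\r\n\x1a\n..."  # placeholder payload
data_url = "data:image/png;base64," + base64.b64encode(raw_bytes).decode()

# Parse it back the way the converter above does.
match = re.match(r"data:image/(\w+);base64,(.+)", data_url)
if not match:
    raise ValueError(f"Invalid data URL format, {data_url[:40]}...")

image_format, image_data = match.groups()
decoded = base64.b64decode(image_data)

assert image_format == "png"
assert decoded == raw_bytes
```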

View file

@@ -13,12 +13,51 @@ from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_metr

 class BaseScoringFn(ABC):
     """
-    Base interface class for all native scoring_fns.
-    Each scoring_fn needs to implement the following methods:
+    Base interface class for Scoring Functions.
+    Each scoring function needs to implement the following methods:
     - score_row(self, row)
     - aggregate(self, scoring_fn_results)
     """

+    def __init__(self, *args, **kwargs) -> None:
+        super().__init__(*args, **kwargs)
+
+    def __str__(self) -> str:
+        return self.__class__.__name__
+
+    @abstractmethod
+    async def score_row(
+        self,
+        input_row: Dict[str, Any],
+        scoring_fn_identifier: Optional[str] = None,
+        scoring_params: Optional[ScoringFnParams] = None,
+    ) -> ScoringResultRow:
+        raise NotImplementedError()
+
+    @abstractmethod
+    async def aggregate(
+        self,
+        scoring_results: List[ScoringResultRow],
+        scoring_fn_identifier: Optional[str] = None,
+        scoring_params: Optional[ScoringFnParams] = None,
+    ) -> Dict[str, Any]:
+        raise NotImplementedError()
+
+    @abstractmethod
+    async def score(
+        self,
+        input_rows: List[Dict[str, Any]],
+        scoring_fn_identifier: Optional[str] = None,
+        scoring_params: Optional[ScoringFnParams] = None,
+    ) -> List[ScoringResultRow]:
+        raise NotImplementedError()
+
+
+class RegisteredBaseScoringFn(BaseScoringFn):
+    """
+    Interface for native scoring functions that are registered in LlamaStack.
+    """
+
     def __init__(self, *args, **kwargs) -> None:
         super().__init__(*args, **kwargs)
         self.supported_fn_defs_registry = {}
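
With this split, `BaseScoringFn` becomes a pure abstract interface while `RegisteredBaseScoringFn` keeps the registry plumbing. A hedged sketch of a minimal subclass against that interface (the class name and metric are invented; plain dicts stand in for `ScoringResultRow`):

```python
from typing import Any, Dict, List, Optional

# RegisteredBaseScoringFn as defined in the hunk above; import path omitted here.


class ExactMatchScoringFn(RegisteredBaseScoringFn):
    """Toy scorer: 1.0 when generated_answer equals expected_answer, else 0.0."""

    async def score_row(
        self,
        input_row: Dict[str, Any],
        scoring_fn_identifier: Optional[str] = None,
        scoring_params=None,
    ) -> Dict[str, Any]:
        match = input_row["generated_answer"] == input_row["expected_answer"]
        return {"score": 1.0 if match else 0.0}

    async def aggregate(
        self,
        scoring_results: List[Dict[str, Any]],
        scoring_fn_identifier: Optional[str] = None,
        scoring_params=None,
    ) -> Dict[str, Any]:
        scores = [r["score"] for r in scoring_results]
        return {"accuracy": sum(scores) / len(scores) if scores else 0.0}

    async def score(
        self,
        input_rows: List[Dict[str, Any]],
        scoring_fn_identifier: Optional[str] = None,
        scoring_params=None,
    ) -> List[Dict[str, Any]]:
        return [
            await self.score_row(row, scoring_fn_identifier, scoring_params)
            for row in input_rows
        ]
```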

View file

@@ -53,7 +53,7 @@ def trace_protocol(cls: Type[T]) -> Type[T]:
                 combined_args = {}
                 for i, arg in enumerate(args):
                     param_name = (
-                        param_names[i] if i < len(param_names) else f"position_{i+1}"
+                        param_names[i] if i < len(param_names) else f"position_{i + 1}"
                     )
                     combined_args[param_name] = serialize_value(arg)
                 for k, v in kwargs.items():
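
The change here is only PEP 8 spacing inside the f-string, but the surrounding logic, mapping positional arguments to parameter names and falling back to `position_N` keys, is easy to show in isolation (a simplified sketch; `serialize_value` is omitted):

```python
import inspect


def combine_args(fn, args, kwargs):
    # Map positional values to parameter names; overflow gets a position_N key.
    param_names = list(inspect.signature(fn).parameters.keys())
    combined = {}
    for i, arg in enumerate(args):
        name = param_names[i] if i < len(param_names) else f"position_{i + 1}"
        combined[name] = arg
    combined.update(kwargs)
    return combined


def greet(greeting, name):
    return f"{greeting}, {name}!"


print(combine_args(greet, ("Hello", "world", "extra"), {"lang": "en"}))
# {'greeting': 'Hello', 'name': 'world', 'position_3': 'extra', 'lang': 'en'}
```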