Merge branch 'main' into vllm

This commit is contained in:
Fred Reiss 2025-01-08 15:47:58 -08:00 committed by GitHub
commit 73fede90a6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
175 changed files with 7948 additions and 876 deletions

View file

@ -13,19 +13,64 @@ import secrets
import string
import uuid
from datetime import datetime
from typing import AsyncGenerator, List, Tuple
from typing import AsyncGenerator, Dict, List, Optional, Tuple
from urllib.parse import urlparse
import httpx
from llama_models.llama3.api.datatypes import BuiltinTool
from llama_stack.apis.agents import * # noqa: F403
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.apis.memory import * # noqa: F403
from llama_stack.apis.memory_banks import * # noqa: F403
from llama_stack.apis.safety import * # noqa: F403
from llama_stack.apis.agents import (
AgentConfig,
AgentTool,
AgentTurnCreateRequest,
AgentTurnResponseEvent,
AgentTurnResponseEventType,
AgentTurnResponseStepCompletePayload,
AgentTurnResponseStepProgressPayload,
AgentTurnResponseStepStartPayload,
AgentTurnResponseStreamChunk,
AgentTurnResponseTurnCompletePayload,
AgentTurnResponseTurnStartPayload,
Attachment,
CodeInterpreterToolDefinition,
FunctionCallToolDefinition,
InferenceStep,
MemoryRetrievalStep,
MemoryToolDefinition,
PhotogenToolDefinition,
SearchToolDefinition,
ShieldCallStep,
StepType,
ToolExecutionStep,
Turn,
WolframAlphaToolDefinition,
)
from llama_stack.apis.common.content_types import InterleavedContent, TextContentItem
from llama_stack.apis.common.content_types import (
InterleavedContent,
TextContentItem,
URL,
)
from llama_stack.apis.inference import (
ChatCompletionResponseEventType,
CompletionMessage,
Inference,
Message,
SamplingParams,
StopReason,
SystemMessage,
ToolCallDelta,
ToolCallParseStatus,
ToolChoice,
ToolDefinition,
ToolResponse,
ToolResponseMessage,
UserMessage,
)
from llama_stack.apis.memory import Memory, MemoryBankDocument, QueryDocumentsResponse
from llama_stack.apis.memory_banks import MemoryBanks, VectorMemoryBankParams
from llama_stack.apis.safety import Safety
from llama_stack.providers.utils.kvstore import KVStore
from llama_stack.providers.utils.memory.vector_store import concat_interleaved_content
@ -539,7 +584,7 @@ class ChatAgent(ShieldRunnerMixin):
tool_call = message.tool_calls[0]
name = tool_call.tool_name
if not isinstance(name, BuiltinTool):
if not isinstance(name, BuiltinTool) or name not in enabled_tools:
yield message
return

View file

@ -9,15 +9,26 @@ import logging
import shutil
import tempfile
import uuid
from typing import AsyncGenerator
from typing import AsyncGenerator, List, Optional, Union
from termcolor import colored
from llama_stack.apis.inference import Inference
from llama_stack.apis.agents import (
AgentConfig,
AgentCreateResponse,
Agents,
AgentSessionCreateResponse,
AgentStepResponse,
AgentTurnCreateRequest,
Attachment,
Session,
Turn,
)
from llama_stack.apis.inference import Inference, ToolResponseMessage, UserMessage
from llama_stack.apis.memory import Memory
from llama_stack.apis.memory_banks import MemoryBanks
from llama_stack.apis.safety import Safety
from llama_stack.apis.agents import * # noqa: F403
from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl

View file

@ -10,9 +10,11 @@ import uuid
from datetime import datetime
from typing import List, Optional
from llama_stack.apis.agents import * # noqa: F403
from pydantic import BaseModel
from llama_stack.apis.agents import Turn
from llama_stack.providers.utils.kvstore import KVStore
log = logging.getLogger(__name__)

View file

@ -7,8 +7,6 @@
from typing import List
from jinja2 import Template
from llama_models.llama3.api import * # noqa: F403
from llama_stack.apis.agents import (
DefaultMemoryQueryGeneratorConfig,
@ -16,7 +14,7 @@ from llama_stack.apis.agents import (
MemoryQueryGenerator,
MemoryQueryGeneratorConfig,
)
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.apis.inference import Message, UserMessage
from llama_stack.providers.utils.inference.prompt_adapter import (
interleaved_content_as_str,
)
@ -64,7 +62,7 @@ async def llm_rag_query_generator(
model = config.model
message = UserMessage(content=content)
response = await inference_api.chat_completion(
model=model,
model_id=model,
messages=[message],
stream=False,
)

View file

@ -9,7 +9,9 @@ import logging
from typing import List
from llama_stack.apis.safety import * # noqa: F403
from llama_stack.apis.inference import Message
from llama_stack.apis.safety import Safety, SafetyViolation, ViolationLevel
log = logging.getLogger(__name__)

View file

@ -8,10 +8,26 @@ from typing import AsyncIterator, List, Optional, Union
import pytest
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.apis.memory import * # noqa: F403
from llama_stack.apis.safety import * # noqa: F403
from llama_stack.apis.agents import * # noqa: F403
from llama_stack.apis.agents import (
AgentConfig,
AgentTurnCreateRequest,
AgentTurnResponseTurnCompletePayload,
)
from llama_stack.apis.inference import (
ChatCompletionResponse,
ChatCompletionResponseEvent,
ChatCompletionResponseStreamChunk,
CompletionMessage,
Message,
ResponseFormat,
SamplingParams,
ToolChoice,
ToolDefinition,
UserMessage,
)
from llama_stack.apis.memory import MemoryBank
from llama_stack.apis.safety import RunShieldResponse
from ..agents import (
AGENT_INSTANCES_BY_ID,

View file

@ -7,7 +7,7 @@
from typing import List
from llama_stack.apis.inference import Message
from llama_stack.apis.safety import * # noqa: F403
from llama_stack.apis.safety import Safety
from ..safety import ShieldRunnerMixin
from .builtin import BaseTool

View file

@ -3,7 +3,7 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.datasetio import * # noqa: F401, F403
from pydantic import BaseModel
class LocalFSDatasetIOConfig(BaseModel): ...

View file

@ -3,18 +3,19 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List, Optional
import pandas
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.datasetio import * # noqa: F403
import base64
import os
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
from urllib.parse import urlparse
import pandas
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult
from llama_stack.apis.datasets import Dataset
from llama_stack.providers.datatypes import DatasetsProtocolPrivate
from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url

View file

@ -3,37 +3,38 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from typing import Any, Dict, List, Optional
from llama_models.llama3.api.datatypes import * # noqa: F403
from tqdm import tqdm
from .....apis.common.job_types import Job
from .....apis.eval.eval import Eval, EvalTaskConfig, EvaluateResponse, JobStatus
from llama_stack.apis.common.type_system import * # noqa: F403
from llama_stack.apis.agents import Agents
from llama_stack.apis.agents import Agents, StepType
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.eval_tasks import EvalTask
from llama_stack.apis.inference import Inference
from llama_stack.apis.inference import Inference, UserMessage
from llama_stack.apis.scoring import Scoring
from llama_stack.distribution.datatypes import Api
from llama_stack.providers.datatypes import EvalTasksProtocolPrivate
from llama_stack.providers.utils.common.data_schema_validator import (
ColumnName,
get_valid_schemas,
validate_dataset_schema,
)
from llama_stack.providers.utils.kvstore import kvstore_impl
from .....apis.common.job_types import Job
from .....apis.eval.eval import Eval, EvalTaskConfig, EvaluateResponse, JobStatus
from .config import MetaReferenceEvalConfig
EVAL_TASKS_PREFIX = "eval_tasks:"
class ColumnName(Enum):
input_query = "input_query"
expected_answer = "expected_answer"
chat_completion_input = "chat_completion_input"
completion_input = "completion_input"
generated_answer = "generated_answer"
class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
class MetaReferenceEvalImpl(
Eval,
EvalTasksProtocolPrivate,
):
def __init__(
self,
config: MetaReferenceEvalConfig,
@ -77,29 +78,6 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
)
self.eval_tasks[task_def.identifier] = task_def
async def validate_eval_input_dataset_schema(self, dataset_id: str) -> None:
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
raise ValueError(f"Dataset {dataset_id} does not have a schema defined.")
expected_schemas = [
{
ColumnName.input_query.value: StringType(),
ColumnName.expected_answer.value: StringType(),
ColumnName.chat_completion_input.value: ChatCompletionInputType(),
},
{
ColumnName.input_query.value: StringType(),
ColumnName.expected_answer.value: StringType(),
ColumnName.completion_input.value: CompletionInputType(),
},
]
if dataset_def.dataset_schema not in expected_schemas:
raise ValueError(
f"Dataset {dataset_id} does not have a correct input schema in {expected_schemas}"
)
async def run_eval(
self,
task_id: str,
@ -109,8 +87,10 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
dataset_id = task_def.dataset_id
candidate = task_config.eval_candidate
scoring_functions = task_def.scoring_functions
await self.validate_eval_input_dataset_schema(dataset_id=dataset_id)
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
validate_dataset_schema(
dataset_def.dataset_schema, get_valid_schemas(Api.eval.value)
)
all_rows = await self.datasetio_api.get_rows_paginated(
dataset_id=dataset_id,
rows_in_page=(
@ -162,11 +142,21 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
)
]
final_event = turn_response[-1].event.payload
generations.append(
{
ColumnName.generated_answer.value: final_event.turn.output_message.content
}
# check if there's a memory retrieval step and extract the context
memory_rag_context = None
for step in final_event.turn.steps:
if step.step_type == StepType.memory_retrieval.value:
memory_rag_context = " ".join(x.text for x in step.inserted_context)
agent_generation = {}
agent_generation[ColumnName.generated_answer.value] = (
final_event.turn.output_message.content
)
if memory_rag_context:
agent_generation[ColumnName.context.value] = memory_rag_context
generations.append(agent_generation)
return generations

View file

@ -6,11 +6,10 @@
from typing import Any, Dict, Optional
from llama_models.datatypes import * # noqa: F403
from llama_stack.apis.inference import * # noqa: F401, F403
from pydantic import BaseModel, field_validator
from llama_stack.apis.inference import QuantizationConfig
from llama_stack.providers.utils.inference import supported_inference_models

View file

@ -32,11 +32,16 @@ from llama_models.llama3.reference_impl.multimodal.model import (
CrossAttentionTransformer,
)
from llama_models.sku_list import resolve_model
from pydantic import BaseModel
from llama_stack.apis.inference import * # noqa: F403
from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerTokenizerData
from pydantic import BaseModel
from llama_stack.apis.inference import (
Fp8QuantizationConfig,
Int4QuantizationConfig,
ResponseFormat,
ResponseFormatType,
)
from llama_stack.distribution.utils.model_utils import model_local_dir
from llama_stack.providers.utils.inference.prompt_adapter import (
@ -44,12 +49,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
CompletionRequestWithRawContent,
)
from .config import (
Fp8QuantizationConfig,
Int4QuantizationConfig,
MetaReferenceInferenceConfig,
MetaReferenceQuantizedInferenceConfig,
)
from .config import MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
log = logging.getLogger(__name__)

View file

@ -14,7 +14,10 @@ from llama_models.llama3.api.datatypes import Model
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_models.sku_list import resolve_model
from llama_stack.apis.inference import ChatCompletionRequest, CompletionRequest
from llama_stack.providers.utils.inference.prompt_adapter import (
ChatCompletionRequestWithRawContent,
CompletionRequestWithRawContent,
)
from .config import MetaReferenceInferenceConfig
from .generation import Llama, model_checkpoint_dir
@ -27,9 +30,9 @@ class ModelRunner:
# the `task` object is the same that is sent to `ModelParallelProcessGroup.run_inference()`
def __call__(self, req: Any):
if isinstance(req, ChatCompletionRequest):
if isinstance(req, ChatCompletionRequestWithRawContent):
return self.llama.chat_completion(req)
elif isinstance(req, CompletionRequest):
elif isinstance(req, CompletionRequestWithRawContent):
return self.llama.completion(req)
else:
raise ValueError(f"Unexpected task type {type(req)}")
@ -100,7 +103,7 @@ class LlamaModelParallelGenerator:
def completion(
self,
request: CompletionRequest,
request: CompletionRequestWithRawContent,
) -> Generator:
req_obj = deepcopy(request)
gen = self.group.run_inference(req_obj)
@ -108,7 +111,7 @@ class LlamaModelParallelGenerator:
def chat_completion(
self,
request: ChatCompletionRequest,
request: ChatCompletionRequestWithRawContent,
) -> Generator:
req_obj = deepcopy(request)
gen = self.group.run_inference(req_obj)

View file

@ -34,7 +34,10 @@ from pydantic import BaseModel, Field
from torch.distributed.launcher.api import elastic_launch, LaunchConfig
from typing_extensions import Annotated
from llama_stack.apis.inference import ChatCompletionRequest, CompletionRequest
from llama_stack.providers.utils.inference.prompt_adapter import (
ChatCompletionRequestWithRawContent,
CompletionRequestWithRawContent,
)
from .generation import TokenResult
@ -79,7 +82,7 @@ class TaskRequest(BaseModel):
type: Literal[ProcessingMessageName.task_request] = (
ProcessingMessageName.task_request
)
task: Union[CompletionRequest, ChatCompletionRequest]
task: Union[CompletionRequestWithRawContent, ChatCompletionRequestWithRawContent]
class TaskResponse(BaseModel):
@ -264,9 +267,6 @@ def launch_dist_group(
init_model_cb: Callable,
**kwargs,
) -> None:
id = uuid.uuid4().hex
dist_url = f"file:///tmp/llama3_{id}_{time.time()}"
with tempfile.TemporaryDirectory() as tmpdir:
# TODO: track workers and if they terminate, tell parent process about it so cleanup can happen
launch_config = LaunchConfig(
@ -315,7 +315,7 @@ def start_model_parallel_process(
# wait until the model is loaded; rank 0 will send a message to indicate it's ready
request_socket.send(encode_msg(ReadyRequest()))
response = request_socket.recv()
_response = request_socket.recv()
log.info("Loaded model...")
return request_socket, process
@ -349,7 +349,10 @@ class ModelParallelProcessGroup:
self.started = False
def run_inference(
self, req: Union[CompletionRequest, ChatCompletionRequest]
self,
req: Union[
CompletionRequestWithRawContent, ChatCompletionRequestWithRawContent
],
) -> Generator:
assert not self.running, "inference already running"

View file

@ -7,10 +7,10 @@
import logging
import os
import uuid
from typing import AsyncGenerator, Optional
from typing import AsyncGenerator, List, Optional
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_models.sku_list import resolve_model
@ -18,9 +18,26 @@ from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.sampling_params import SamplingParams as VLLMSamplingParams
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseStreamChunk,
CompletionResponse,
CompletionResponseStreamChunk,
EmbeddingsResponse,
Inference,
LogProbConfig,
Message,
ResponseFormat,
SamplingParams,
ToolChoice,
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.apis.models import Model
from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.utils.inference.openai_compat import (
OpenAICompatCompletionChoice,
OpenAICompatCompletionResponse,

View file

@ -16,11 +16,14 @@ import faiss
import numpy as np
from numpy.typing import NDArray
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.memory import * # noqa: F403
from llama_stack.apis.inference import InterleavedContent
from llama_stack.apis.memory_banks import MemoryBankType, VectorMemoryBank
from llama_stack.apis.memory import (
Chunk,
Memory,
MemoryBankDocument,
QueryDocumentsResponse,
)
from llama_stack.apis.memory_banks import MemoryBank, MemoryBankType, VectorMemoryBank
from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.memory.vector_store import (

View file

@ -90,18 +90,24 @@ class TorchtuneCheckpointer:
model_file_path.mkdir(parents=True, exist_ok=True)
# copy the related files for inference
shutil.copy(
Path.joinpath(self._checkpoint_dir, "params.json"),
Path.joinpath(model_file_path, "params.json"),
)
shutil.copy(
Path.joinpath(self._checkpoint_dir, "tokenizer.model"),
Path.joinpath(model_file_path, "tokenizer.model"),
)
shutil.copy(
Path.joinpath(self._checkpoint_dir, "orig_params.json"),
Path.joinpath(model_file_path, "orig_params.json"),
)
source_path = Path.joinpath(self._checkpoint_dir, "params.json")
if source_path.exists():
shutil.copy(
source_path,
Path.joinpath(model_file_path, "params.json"),
)
source_path = Path.joinpath(self._checkpoint_dir, "tokenizer.model")
if source_path.exists():
shutil.copy(
source_path,
Path.joinpath(model_file_path, "tokenizer.model"),
)
source_path = Path.joinpath(self._checkpoint_dir, "orig_params.json")
if source_path.exists():
shutil.copy(
source_path,
Path.joinpath(model_file_path, "orig_params.json"),
)
if not adapter_only:
model_state_dict = state_dict[training.MODEL_KEY]

View file

@ -14,14 +14,16 @@ from enum import Enum
from typing import Any, Callable, Dict, List
import torch
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.common.type_system import * # noqa
from llama_models.datatypes import Model
from llama_models.sku_list import resolve_model
from llama_stack.apis.common.type_system import ParamType
from llama_stack.apis.common.type_system import ParamType, StringType
from llama_stack.apis.datasets import Datasets
from torchtune.models.llama3 import llama3_tokenizer, lora_llama3_8b
from pydantic import BaseModel
from torchtune.models.llama3 import llama3_tokenizer
from torchtune.models.llama3._tokenizer import Llama3Tokenizer
from torchtune.models.llama3_1 import lora_llama3_1_8b
from torchtune.models.llama3_2 import lora_llama3_2_3b
@ -48,8 +50,8 @@ MODEL_CONFIGS: Dict[str, ModelConfig] = {
tokenizer_type=llama3_tokenizer,
checkpoint_type="LLAMA3_2",
),
"Llama-3-8B-Instruct": ModelConfig(
model_definition=lora_llama3_8b,
"Llama3.1-8B-Instruct": ModelConfig(
model_definition=lora_llama3_1_8b,
tokenizer_type=llama3_tokenizer,
checkpoint_type="LLAMA3",
),

View file

@ -3,11 +3,26 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from datetime import datetime
from typing import Any, Dict, List, Optional
from llama_models.schema_utils import webmethod
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.post_training import (
AlgorithmConfig,
DPOAlignmentConfig,
JobStatus,
LoraFinetuningConfig,
PostTrainingJob,
PostTrainingJobArtifactsResponse,
PostTrainingJobStatusResponse,
TrainingConfig,
)
from llama_stack.providers.inline.post_training.torchtune.config import (
TorchtunePostTrainingConfig,
)
from llama_stack.apis.post_training import * # noqa
from llama_stack.providers.inline.post_training.torchtune.recipes.lora_finetuning_single_device import (
LoraFinetuningSingleDevice,
)

View file

@ -7,6 +7,7 @@
import logging
import os
import time
from datetime import datetime
from functools import partial
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
@ -14,27 +15,33 @@ from typing import Any, Dict, List, Optional, Tuple
import torch
from llama_models.sku_list import resolve_model
from llama_stack.apis.common.training_types import PostTrainingMetric
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.post_training import (
AlgorithmConfig,
Checkpoint,
LoraFinetuningConfig,
OptimizerConfig,
TrainingConfig,
)
from llama_stack.distribution.utils.config_dirs import DEFAULT_CHECKPOINT_DIR
from llama_stack.providers.inline.post_training.torchtune.common.checkpointer import (
TorchtuneCheckpointer,
)
from torch import nn
from torchtune import utils as torchtune_utils
from torchtune.training.metric_logging import DiskLogger
from tqdm import tqdm
from llama_stack.apis.post_training import * # noqa
from llama_stack.distribution.utils.model_utils import model_local_dir
from llama_stack.providers.inline.post_training.torchtune.common import utils
from llama_stack.providers.inline.post_training.torchtune.common.checkpointer import (
TorchtuneCheckpointer,
)
from llama_stack.providers.inline.post_training.torchtune.config import (
TorchtunePostTrainingConfig,
)
from llama_stack.providers.inline.post_training.torchtune.datasets.sft import SFTDataset
from torch import nn
from torch.optim import Optimizer
from torch.utils.data import DataLoader, DistributedSampler
from torchtune import modules, training
from torchtune import modules, training, utils as torchtune_utils
from torchtune.data import AlpacaToMessages, padded_collate_sft
from torchtune.modules.loss import CEWithChunkedOutputLoss
@ -43,11 +50,12 @@ from torchtune.modules.peft import (
get_adapter_state_dict,
get_lora_module_names,
get_merged_lora_ckpt,
load_dora_magnitudes,
set_trainable_params,
validate_missing_and_unexpected_for_lora,
)
from torchtune.training.lr_schedulers import get_cosine_schedule_with_warmup
from torchtune.training.metric_logging import DiskLogger
from tqdm import tqdm
log = logging.getLogger(__name__)
@ -110,6 +118,10 @@ class LoraFinetuningSingleDevice:
self.checkpoint_dir = config.checkpoint_dir
else:
model = resolve_model(self.model_id)
if model is None:
raise ValueError(
f"{self.model_id} not found. Your model id should be in the llama models SKU list"
)
self.checkpoint_dir = model_checkpoint_dir(model)
self._output_dir = str(DEFAULT_CHECKPOINT_DIR)
@ -125,6 +137,7 @@ class LoraFinetuningSingleDevice:
self.global_step = 0
self._gradient_accumulation_steps = training_config.gradient_accumulation_steps
self.max_validation_steps = training_config.max_validation_steps
self._clip_grad_norm = 1.0
self._enable_activation_checkpointing = (
@ -277,7 +290,6 @@ class LoraFinetuningSingleDevice:
for m in model.modules():
if hasattr(m, "initialize_dora_magnitude"):
m.initialize_dora_magnitude()
load_dora_magnitudes(model)
if lora_weights_state_dict:
lora_missing, lora_unexpected = model.load_state_dict(
lora_weights_state_dict, strict=False
@ -572,7 +584,7 @@ class LoraFinetuningSingleDevice:
log.info("Starting validation...")
pbar = tqdm(total=len(self._validation_dataloader))
for idx, batch in enumerate(self._validation_dataloader):
if idx == 10:
if idx == self.max_validation_steps:
break
torchtune_utils.batch_to_device(batch, self._device)

View file

@ -7,8 +7,14 @@
import logging
from typing import Any, Dict, List
from llama_stack.apis.safety import * # noqa: F403
from llama_stack.apis.inference import Message
from llama_stack.apis.safety import (
RunShieldResponse,
Safety,
SafetyViolation,
ViolationLevel,
)
from llama_stack.apis.shields import Shield
from llama_stack.providers.utils.inference.prompt_adapter import (
interleaved_content_as_str,
)

View file

@ -9,10 +9,24 @@ import re
from string import Template
from typing import Any, Dict, List, Optional
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.apis.safety import * # noqa: F403
from llama_models.datatypes import CoreModelId
from llama_models.llama3.api.datatypes import Role
from llama_stack.apis.common.content_types import ImageContentItem, TextContentItem
from llama_stack.apis.inference import (
ChatCompletionResponseEventType,
Inference,
Message,
UserMessage,
)
from llama_stack.apis.safety import (
RunShieldResponse,
Safety,
SafetyViolation,
ViolationLevel,
)
from llama_stack.apis.shields import Shield
from llama_stack.distribution.datatypes import Api
from llama_stack.providers.datatypes import ShieldsProtocolPrivate

View file

@ -11,11 +11,16 @@ import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from llama_stack.distribution.utils.model_utils import model_local_dir
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.apis.safety import * # noqa: F403
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.inference import Message
from llama_stack.apis.safety import (
RunShieldResponse,
Safety,
SafetyViolation,
ViolationLevel,
)
from llama_stack.apis.shields import Shield
from llama_stack.distribution.utils.model_utils import model_local_dir
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
from llama_stack.providers.utils.inference.prompt_adapter import (
interleaved_content_as_str,

View file

@ -3,16 +3,24 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List
from typing import Any, Dict, List, Optional
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.scoring import * # noqa: F403
from llama_stack.apis.scoring_functions import * # noqa: F403
from llama_stack.apis.common.type_system import * # noqa: F403
from llama_stack.apis.datasetio import * # noqa: F403
from llama_stack.apis.datasets import * # noqa: F403
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.scoring import (
ScoreBatchResponse,
ScoreResponse,
Scoring,
ScoringResult,
)
from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams
from llama_stack.distribution.datatypes import Api
from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
from llama_stack.providers.utils.common.data_schema_validator import (
get_valid_schemas,
validate_dataset_schema,
)
from .config import BasicScoringConfig
from .scoring_fn.equality_scoring_fn import EqualityScoringFn
from .scoring_fn.regex_parser_scoring_fn import RegexParserScoringFn
@ -21,7 +29,10 @@ from .scoring_fn.subset_of_scoring_fn import SubsetOfScoringFn
FIXED_FNS = [EqualityScoringFn, SubsetOfScoringFn, RegexParserScoringFn]
class BasicScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
class BasicScoringImpl(
Scoring,
ScoringFunctionsProtocolPrivate,
):
def __init__(
self,
config: BasicScoringConfig,
@ -58,30 +69,17 @@ class BasicScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
async def register_scoring_function(self, function_def: ScoringFn) -> None:
raise NotImplementedError("Register scoring function not implemented yet")
async def validate_scoring_input_dataset_schema(self, dataset_id: str) -> None:
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
raise ValueError(
f"Dataset {dataset_id} does not have a schema defined. Please define a schema for the dataset."
)
for required_column in ["generated_answer", "expected_answer", "input_query"]:
if required_column not in dataset_def.dataset_schema:
raise ValueError(
f"Dataset {dataset_id} does not have a '{required_column}' column."
)
if dataset_def.dataset_schema[required_column].type != "string":
raise ValueError(
f"Dataset {dataset_id} does not have a '{required_column}' column of type 'string'."
)
async def score_batch(
self,
dataset_id: str,
scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
save_results_dataset: bool = False,
) -> ScoreBatchResponse:
await self.validate_scoring_input_dataset_schema(dataset_id=dataset_id)
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
validate_dataset_schema(
dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value)
)
all_rows = await self.datasetio_api.get_rows_paginated(
dataset_id=dataset_id,
rows_in_page=-1,

View file

@ -9,12 +9,12 @@ from typing import Any, Dict, Optional
from llama_stack.apis.scoring import ScoringResultRow
from llama_stack.apis.scoring_functions import ScoringFnParams
from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn
from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
from .fn_defs.equality import equality
class EqualityScoringFn(BaseScoringFn):
class EqualityScoringFn(RegisteredBaseScoringFn):
"""
A scoring_fn that assigns a score of 1.0 if the input string matches the target string, and 0.0 otherwise.
"""

View file

@ -9,14 +9,14 @@ from typing import Any, Dict, Optional
from llama_stack.apis.scoring import ScoringResultRow
from llama_stack.apis.scoring_functions import ScoringFnParams, ScoringFnParamsType
from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn
from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
from .fn_defs.regex_parser_multiple_choice_answer import (
regex_parser_multiple_choice_answer,
)
class RegexParserScoringFn(BaseScoringFn):
class RegexParserScoringFn(RegisteredBaseScoringFn):
"""
A scoring_fn that parses answer from generated response according to context and check match with expected_answer.
"""

View file

@ -8,12 +8,12 @@ from typing import Any, Dict, Optional
from llama_stack.apis.scoring import ScoringResultRow
from llama_stack.apis.scoring_functions import ScoringFnParams
from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn
from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
from .fn_defs.subset_of import subset_of
class SubsetOfScoringFn(BaseScoringFn):
class SubsetOfScoringFn(RegisteredBaseScoringFn):
"""
A scoring_fn that assigns a score of 1.0 if the expected string is included in the generated string, and 0.0 otherwise.
"""

View file

@ -3,32 +3,115 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.scoring import * # noqa: F403
from llama_stack.apis.scoring_functions import * # noqa: F403
from llama_stack.apis.common.type_system import * # noqa: F403
from llama_stack.apis.datasetio import * # noqa: F403
from llama_stack.apis.datasets import * # noqa: F403
import os
from typing import Any, Dict, List, Optional
from autoevals.llm import Factuality
from autoevals.ragas import AnswerCorrectness
from autoevals.ragas import (
AnswerCorrectness,
AnswerRelevancy,
AnswerSimilarity,
ContextEntityRecall,
ContextPrecision,
ContextRecall,
ContextRelevancy,
Faithfulness,
)
from pydantic import BaseModel
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.scoring import (
ScoreBatchResponse,
ScoreResponse,
Scoring,
ScoringResult,
ScoringResultRow,
)
from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams
from llama_stack.distribution.datatypes import Api
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
from llama_stack.providers.utils.common.data_schema_validator import (
get_valid_schemas,
validate_dataset_schema,
validate_row_schema,
)
from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_average
from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_metrics
from .config import BraintrustScoringConfig
from .scoring_fn.fn_defs.answer_correctness import answer_correctness_fn_def
from .scoring_fn.fn_defs.answer_relevancy import answer_relevancy_fn_def
from .scoring_fn.fn_defs.answer_similarity import answer_similarity_fn_def
from .scoring_fn.fn_defs.context_entity_recall import context_entity_recall_fn_def
from .scoring_fn.fn_defs.context_precision import context_precision_fn_def
from .scoring_fn.fn_defs.context_recall import context_recall_fn_def
from .scoring_fn.fn_defs.context_relevancy import context_relevancy_fn_def
from .scoring_fn.fn_defs.factuality import factuality_fn_def
from .scoring_fn.fn_defs.faithfulness import faithfulness_fn_def
class BraintrustScoringFnEntry(BaseModel):
identifier: str
evaluator: Any
fn_def: ScoringFn
SUPPORTED_BRAINTRUST_SCORING_FN_ENTRY = [
BraintrustScoringFnEntry(
identifier="braintrust::factuality",
evaluator=Factuality(),
fn_def=factuality_fn_def,
),
BraintrustScoringFnEntry(
identifier="braintrust::answer-correctness",
evaluator=AnswerCorrectness(),
fn_def=answer_correctness_fn_def,
),
BraintrustScoringFnEntry(
identifier="braintrust::answer-relevancy",
evaluator=AnswerRelevancy(),
fn_def=answer_relevancy_fn_def,
),
BraintrustScoringFnEntry(
identifier="braintrust::answer-similarity",
evaluator=AnswerSimilarity(),
fn_def=answer_similarity_fn_def,
),
BraintrustScoringFnEntry(
identifier="braintrust::faithfulness",
evaluator=Faithfulness(),
fn_def=faithfulness_fn_def,
),
BraintrustScoringFnEntry(
identifier="braintrust::context-entity-recall",
evaluator=ContextEntityRecall(),
fn_def=context_entity_recall_fn_def,
),
BraintrustScoringFnEntry(
identifier="braintrust::context-precision",
evaluator=ContextPrecision(),
fn_def=context_precision_fn_def,
),
BraintrustScoringFnEntry(
identifier="braintrust::context-recall",
evaluator=ContextRecall(),
fn_def=context_recall_fn_def,
),
BraintrustScoringFnEntry(
identifier="braintrust::context-relevancy",
evaluator=ContextRelevancy(),
fn_def=context_relevancy_fn_def,
),
]
class BraintrustScoringImpl(
Scoring, ScoringFunctionsProtocolPrivate, NeedsRequestProviderData
Scoring,
ScoringFunctionsProtocolPrivate,
NeedsRequestProviderData,
):
def __init__(
self,
@ -41,12 +124,12 @@ class BraintrustScoringImpl(
self.datasets_api = datasets_api
self.braintrust_evaluators = {
"braintrust::factuality": Factuality(),
"braintrust::answer-correctness": AnswerCorrectness(),
entry.identifier: entry.evaluator
for entry in SUPPORTED_BRAINTRUST_SCORING_FN_ENTRY
}
self.supported_fn_defs_registry = {
factuality_fn_def.identifier: factuality_fn_def,
answer_correctness_fn_def.identifier: answer_correctness_fn_def,
entry.identifier: entry.fn_def
for entry in SUPPORTED_BRAINTRUST_SCORING_FN_ENTRY
}
async def initialize(self) -> None: ...
@ -67,23 +150,6 @@ class BraintrustScoringImpl(
"Registering scoring function not allowed for braintrust provider"
)
async def validate_scoring_input_dataset_schema(self, dataset_id: str) -> None:
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
raise ValueError(
f"Dataset {dataset_id} does not have a schema defined. Please define a schema for the dataset."
)
for required_column in ["generated_answer", "expected_answer", "input_query"]:
if required_column not in dataset_def.dataset_schema:
raise ValueError(
f"Dataset {dataset_id} does not have a '{required_column}' column."
)
if dataset_def.dataset_schema[required_column].type != "string":
raise ValueError(
f"Dataset {dataset_id} does not have a '{required_column}' column of type 'string'."
)
async def set_api_key(self) -> None:
# api key is in the request headers
if not self.config.openai_api_key:
@ -99,11 +165,16 @@ class BraintrustScoringImpl(
async def score_batch(
self,
dataset_id: str,
scoring_functions: List[str],
scoring_functions: Dict[str, Optional[ScoringFnParams]],
save_results_dataset: bool = False,
) -> ScoreBatchResponse:
await self.set_api_key()
await self.validate_scoring_input_dataset_schema(dataset_id=dataset_id)
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
validate_dataset_schema(
dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value)
)
all_rows = await self.datasetio_api.get_rows_paginated(
dataset_id=dataset_id,
rows_in_page=-1,
@ -123,6 +194,7 @@ class BraintrustScoringImpl(
async def score_row(
self, input_row: Dict[str, Any], scoring_fn_identifier: Optional[str] = None
) -> ScoringResultRow:
validate_row_schema(input_row, get_valid_schemas(Api.scoring.value))
await self.set_api_key()
assert scoring_fn_identifier is not None, "scoring_fn_identifier cannot be None"
expected_answer = input_row["expected_answer"]
@ -130,12 +202,19 @@ class BraintrustScoringImpl(
input_query = input_row["input_query"]
evaluator = self.braintrust_evaluators[scoring_fn_identifier]
result = evaluator(generated_answer, expected_answer, input=input_query)
result = evaluator(
generated_answer,
expected_answer,
input=input_query,
context=input_row["context"] if "context" in input_row else None,
)
score = result.score
return {"score": score, "metadata": result.metadata}
async def score(
self, input_rows: List[Dict[str, Any]], scoring_functions: List[str]
self,
input_rows: List[Dict[str, Any]],
scoring_functions: Dict[str, Optional[ScoringFnParams]],
) -> ScoreResponse:
await self.set_api_key()
res = {}
@ -147,8 +226,17 @@ class BraintrustScoringImpl(
await self.score_row(input_row, scoring_fn_id)
for input_row in input_rows
]
aggregation_functions = [AggregationFunctionType.average]
agg_results = aggregate_average(score_results)
aggregation_functions = self.supported_fn_defs_registry[
scoring_fn_id
].params.aggregation_functions
# override scoring_fn params if provided
if scoring_functions[scoring_fn_id] is not None:
override_params = scoring_functions[scoring_fn_id]
if override_params.aggregation_functions:
aggregation_functions = override_params.aggregation_functions
agg_results = aggregate_metrics(score_results, aggregation_functions)
res[scoring_fn_id] = ScoringResult(
score_rows=score_results,
aggregated_results=agg_results,

View file

@ -3,7 +3,9 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.scoring import * # noqa: F401, F403
from typing import Any, Dict, Optional
from pydantic import BaseModel, Field
class BraintrustScoringConfig(BaseModel):

View file

@ -5,14 +5,23 @@
# the root directory of this source tree.
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import ScoringFn
from llama_stack.apis.scoring_functions import (
AggregationFunctionType,
BasicScoringFnParams,
ScoringFn,
)
answer_correctness_fn_def = ScoringFn(
identifier="braintrust::answer-correctness",
description="Scores the correctness of the answer based on the ground truth.. One of Braintrust LLM basd scorer https://github.com/braintrustdata/autoevals/blob/main/py/autoevals/llm.py",
params=None,
description=(
"Scores the correctness of the answer based on the ground truth. "
"Uses Braintrust LLM-based scorer from autoevals library."
),
provider_id="braintrust",
provider_resource_id="answer-correctness",
return_type=NumberType(),
params=BasicScoringFnParams(
aggregation_functions=[AggregationFunctionType.average]
),
)

View file

@ -0,0 +1,26 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import (
AggregationFunctionType,
BasicScoringFnParams,
ScoringFn,
)
answer_relevancy_fn_def = ScoringFn(
identifier="braintrust::answer-relevancy",
description=(
"Test output relevancy against the input query using Braintrust LLM scorer. "
"See: github.com/braintrustdata/autoevals"
),
provider_id="braintrust",
provider_resource_id="answer-relevancy",
return_type=NumberType(),
params=BasicScoringFnParams(
aggregation_functions=[AggregationFunctionType.average]
),
)

View file

@ -0,0 +1,26 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import (
AggregationFunctionType,
BasicScoringFnParams,
ScoringFn,
)
answer_similarity_fn_def = ScoringFn(
identifier="braintrust::answer-similarity",
description=(
"Test output similarity against expected value using Braintrust LLM scorer. "
"See: github.com/braintrustdata/autoevals"
),
provider_id="braintrust",
provider_resource_id="answer-similarity",
return_type=NumberType(),
params=BasicScoringFnParams(
aggregation_functions=[AggregationFunctionType.average]
),
)

View file

@ -0,0 +1,26 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import (
AggregationFunctionType,
BasicScoringFnParams,
ScoringFn,
)
context_entity_recall_fn_def = ScoringFn(
identifier="braintrust::context-entity-recall",
description=(
"Evaluates how well the context captures the named entities present in the "
"reference answer. See: github.com/braintrustdata/autoevals"
),
provider_id="braintrust",
provider_resource_id="context-entity-recall",
return_type=NumberType(),
params=BasicScoringFnParams(
aggregation_functions=[AggregationFunctionType.average]
),
)

View file

@ -0,0 +1,26 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import (
AggregationFunctionType,
BasicScoringFnParams,
ScoringFn,
)
context_precision_fn_def = ScoringFn(
identifier="braintrust::context-precision",
description=(
"Measures how much of the provided context is actually relevant to answering the "
"question. See: github.com/braintrustdata/autoevals"
),
provider_id="braintrust",
provider_resource_id="context-precision",
return_type=NumberType(),
params=BasicScoringFnParams(
aggregation_functions=[AggregationFunctionType.average]
),
)

View file

@ -0,0 +1,26 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import (
AggregationFunctionType,
BasicScoringFnParams,
ScoringFn,
)
context_recall_fn_def = ScoringFn(
identifier="braintrust::context-recall",
description=(
"Evaluates how well the context covers the information needed to answer the "
"question. See: github.com/braintrustdata/autoevals"
),
provider_id="braintrust",
provider_resource_id="context-recall",
return_type=NumberType(),
params=BasicScoringFnParams(
aggregation_functions=[AggregationFunctionType.average]
),
)

View file

@ -0,0 +1,26 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import (
AggregationFunctionType,
BasicScoringFnParams,
ScoringFn,
)
context_relevancy_fn_def = ScoringFn(
identifier="braintrust::context-relevancy",
description=(
"Assesses how relevant the provided context is to the given question. "
"See: github.com/braintrustdata/autoevals"
),
provider_id="braintrust",
provider_resource_id="context-relevancy",
return_type=NumberType(),
params=BasicScoringFnParams(
aggregation_functions=[AggregationFunctionType.average]
),
)

View file

@ -5,14 +5,23 @@
# the root directory of this source tree.
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import ScoringFn
from llama_stack.apis.scoring_functions import (
AggregationFunctionType,
BasicScoringFnParams,
ScoringFn,
)
factuality_fn_def = ScoringFn(
identifier="braintrust::factuality",
description="Test whether an output is factual, compared to an original (`expected`) value. One of Braintrust LLM basd scorer https://github.com/braintrustdata/autoevals/blob/main/py/autoevals/llm.py",
params=None,
description=(
"Test output factuality against expected value using Braintrust LLM scorer. "
"See: github.com/braintrustdata/autoevals"
),
provider_id="braintrust",
provider_resource_id="factuality",
return_type=NumberType(),
params=BasicScoringFnParams(
aggregation_functions=[AggregationFunctionType.average]
),
)

View file

@ -0,0 +1,26 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import (
AggregationFunctionType,
BasicScoringFnParams,
ScoringFn,
)
faithfulness_fn_def = ScoringFn(
identifier="braintrust::faithfulness",
description=(
"Test output faithfulness to the input query using Braintrust LLM scorer. "
"See: github.com/braintrustdata/autoevals"
),
provider_id="braintrust",
provider_resource_id="faithfulness",
return_type=NumberType(),
params=BasicScoringFnParams(
aggregation_functions=[AggregationFunctionType.average]
),
)

View file

@ -16,7 +16,12 @@ from llama_stack.apis.scoring import (
ScoringResult,
)
from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams
from llama_stack.distribution.datatypes import Api
from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
from llama_stack.providers.utils.common.data_schema_validator import (
get_valid_schemas,
validate_dataset_schema,
)
from .config import LlmAsJudgeScoringConfig
from .scoring_fn.llm_as_judge_scoring_fn import LlmAsJudgeScoringFn
@ -25,7 +30,10 @@ from .scoring_fn.llm_as_judge_scoring_fn import LlmAsJudgeScoringFn
LLM_JUDGE_FNS = [LlmAsJudgeScoringFn]
class LlmAsJudgeScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
class LlmAsJudgeScoringImpl(
Scoring,
ScoringFunctionsProtocolPrivate,
):
def __init__(
self,
config: LlmAsJudgeScoringConfig,
@ -65,30 +73,17 @@ class LlmAsJudgeScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
async def register_scoring_function(self, function_def: ScoringFn) -> None:
raise NotImplementedError("Register scoring function not implemented yet")
async def validate_scoring_input_dataset_schema(self, dataset_id: str) -> None:
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
raise ValueError(
f"Dataset {dataset_id} does not have a schema defined. Please define a schema for the dataset."
)
for required_column in ["generated_answer", "expected_answer", "input_query"]:
if required_column not in dataset_def.dataset_schema:
raise ValueError(
f"Dataset {dataset_id} does not have a '{required_column}' column."
)
if dataset_def.dataset_schema[required_column].type != "string":
raise ValueError(
f"Dataset {dataset_id} does not have a '{required_column}' column of type 'string'."
)
async def score_batch(
self,
dataset_id: str,
scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
save_results_dataset: bool = False,
) -> ScoreBatchResponse:
await self.validate_scoring_input_dataset_schema(dataset_id=dataset_id)
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
validate_dataset_schema(
dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value)
)
all_rows = await self.datasetio_api.get_rows_paginated(
dataset_id=dataset_id,
rows_in_page=-1,

View file

@ -12,14 +12,14 @@ from llama_stack.apis.inference.inference import Inference
from llama_stack.apis.scoring import ScoringResultRow
from llama_stack.apis.scoring_functions import ScoringFnParams
from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn
from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn
from .fn_defs.llm_as_judge_405b_simpleqa import llm_as_judge_405b_simpleqa
from .fn_defs.llm_as_judge_base import llm_as_judge_base
class LlmAsJudgeScoringFn(BaseScoringFn):
class LlmAsJudgeScoringFn(RegisteredBaseScoringFn):
"""
A scoring_fn that assigns
"""

View file

@ -17,6 +17,22 @@ from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor
from opentelemetry.semconv.resource import ResourceAttributes
from llama_stack.apis.telemetry import (
Event,
MetricEvent,
QueryCondition,
SpanEndPayload,
SpanStartPayload,
SpanStatus,
SpanWithStatus,
StructuredLogEvent,
Telemetry,
Trace,
UnstructuredLogEvent,
)
from llama_stack.distribution.datatypes import Api
from llama_stack.providers.inline.telemetry.meta_reference.console_span_processor import (
ConsoleSpanProcessor,
)
@ -27,10 +43,6 @@ from llama_stack.providers.inline.telemetry.meta_reference.sqlite_span_processor
from llama_stack.providers.utils.telemetry.dataset_mixin import TelemetryDatasetMixin
from llama_stack.providers.utils.telemetry.sqlite_trace_store import SQLiteTraceStore
from llama_stack.apis.telemetry import * # noqa: F403
from llama_stack.distribution.datatypes import Api
from .config import TelemetryConfig, TelemetrySink
_GLOBAL_STORAGE = {
@ -100,8 +112,6 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
async def shutdown(self) -> None:
trace.get_tracer_provider().force_flush()
trace.get_tracer_provider().shutdown()
metrics.get_meter_provider().shutdown()
async def log_event(self, event: Event, ttl_seconds: int = 604800) -> None:
if isinstance(event, UnstructuredLogEvent):

View file

@ -4,12 +4,10 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.telemetry import Telemetry
from .config import SampleConfig
from llama_stack.apis.telemetry import * # noqa: F403
class SampleTelemetryImpl(Telemetry):
def __init__(self, config: SampleConfig):
self.config = config

View file

@ -0,0 +1,20 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pydantic import BaseModel
from .brave_search import BraveSearchToolRuntimeImpl
from .config import BraveSearchToolConfig
class BraveSearchToolProviderDataValidator(BaseModel):
api_key: str
async def get_provider_impl(config: BraveSearchToolConfig, _deps):
impl = BraveSearchToolRuntimeImpl(config)
await impl.initialize()
return impl

View file

@ -0,0 +1,123 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List
import requests
from llama_stack.apis.tools import Tool, ToolGroupDef, ToolInvocationResult, ToolRuntime
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.datatypes import ToolsProtocolPrivate
from .config import BraveSearchToolConfig
class BraveSearchToolRuntimeImpl(
ToolsProtocolPrivate, ToolRuntime, NeedsRequestProviderData
):
def __init__(self, config: BraveSearchToolConfig):
self.config = config
async def initialize(self):
pass
async def register_tool(self, tool: Tool):
if tool.identifier != "brave_search":
raise ValueError(f"Tool identifier {tool.identifier} is not supported")
async def unregister_tool(self, tool_id: str) -> None:
return
def _get_api_key(self) -> str:
if self.config.api_key:
return self.config.api_key
provider_data = self.get_request_provider_data()
if provider_data is None or not provider_data.api_key:
raise ValueError(
'Pass Search provider\'s API Key in the header X-LlamaStack-ProviderData as { "api_key": <your api key>}'
)
return provider_data.api_key
async def discover_tools(self, tool_group: ToolGroupDef) -> List[Tool]:
raise NotImplementedError("Brave search tool group not supported")
async def invoke_tool(
self, tool_name: str, args: Dict[str, Any]
) -> ToolInvocationResult:
api_key = self._get_api_key()
url = "https://api.search.brave.com/res/v1/web/search"
headers = {
"X-Subscription-Token": api_key,
"Accept-Encoding": "gzip",
"Accept": "application/json",
}
payload = {"q": args["query"]}
response = requests.get(url=url, params=payload, headers=headers)
response.raise_for_status()
results = self._clean_brave_response(response.json())
content_items = "\n".join([str(result) for result in results])
return ToolInvocationResult(
content=content_items,
)
def _clean_brave_response(self, search_response):
clean_response = []
if "mixed" in search_response:
mixed_results = search_response["mixed"]
for m in mixed_results["main"][: self.config.max_results]:
r_type = m["type"]
results = search_response[r_type]["results"]
cleaned = self._clean_result_by_type(r_type, results, m.get("index"))
clean_response.append(cleaned)
return clean_response
def _clean_result_by_type(self, r_type, results, idx=None):
type_cleaners = {
"web": (
["type", "title", "url", "description", "date", "extra_snippets"],
lambda x: x[idx],
),
"faq": (["type", "question", "answer", "title", "url"], lambda x: x),
"infobox": (
["type", "title", "url", "description", "long_desc"],
lambda x: x[idx],
),
"videos": (["type", "url", "title", "description", "date"], lambda x: x),
"locations": (
[
"type",
"title",
"url",
"description",
"coordinates",
"postal_address",
"contact",
"rating",
"distance",
"zoom_level",
],
lambda x: x,
),
"news": (["type", "title", "url", "description"], lambda x: x),
}
if r_type not in type_cleaners:
return ""
selected_keys, result_selector = type_cleaners[r_type]
results = result_selector(results)
if isinstance(results, list):
cleaned = [
{k: v for k, v in item.items() if k in selected_keys}
for item in results
]
else:
cleaned = {k: v for k, v in results.items() if k in selected_keys}
return str(cleaned)

View file

@ -0,0 +1,20 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Optional
from pydantic import BaseModel, Field
class BraveSearchToolConfig(BaseModel):
api_key: Optional[str] = Field(
default=None,
description="The Brave Search API Key",
)
max_results: int = Field(
default=3,
description="The maximum number of results to return",
)