Merge branch 'main' into add-nim-completion-api

This commit is contained in:
Matthew Farrellee 2024-12-11 13:07:23 -05:00
commit df3c239573
199 changed files with 7739 additions and 814 deletions

View file

@ -63,6 +63,8 @@ class MemoryBanksProtocolPrivate(Protocol):
class DatasetsProtocolPrivate(Protocol):
async def register_dataset(self, dataset: Dataset) -> None: ...
async def unregister_dataset(self, dataset_id: str) -> None: ...
class ScoringFunctionsProtocolPrivate(Protocol):
async def list_scoring_functions(self) -> List[ScoringFn]: ...

View file

@ -10,9 +10,7 @@ import logging
import os
import re
import secrets
import shutil
import string
import tempfile
import uuid
from datetime import datetime
from typing import AsyncGenerator, List, Tuple
@ -57,6 +55,7 @@ class ChatAgent(ShieldRunnerMixin):
self,
agent_id: str,
agent_config: AgentConfig,
tempdir: str,
inference_api: Inference,
memory_api: Memory,
memory_banks_api: MemoryBanks,
@ -65,14 +64,13 @@ class ChatAgent(ShieldRunnerMixin):
):
self.agent_id = agent_id
self.agent_config = agent_config
self.tempdir = tempdir
self.inference_api = inference_api
self.memory_api = memory_api
self.memory_banks_api = memory_banks_api
self.safety_api = safety_api
self.storage = AgentPersistence(agent_id, persistence_store)
self.tempdir = tempfile.mkdtemp()
builtin_tools = []
for tool_defn in agent_config.tools:
if isinstance(tool_defn, WolframAlphaToolDefinition):
@ -103,9 +101,6 @@ class ChatAgent(ShieldRunnerMixin):
output_shields=agent_config.output_shields,
)
def __del__(self):
shutil.rmtree(self.tempdir)
def turn_to_messages(self, turn: Turn) -> List[Message]:
messages = []
@ -144,87 +139,91 @@ class ChatAgent(ShieldRunnerMixin):
async def create_session(self, name: str) -> str:
return await self.storage.create_session(name)
@tracing.span("create_and_execute_turn")
async def create_and_execute_turn(
self, request: AgentTurnCreateRequest
) -> AsyncGenerator:
assert request.stream is True, "Non-streaming not supported"
with tracing.span("create_and_execute_turn") as span:
span.set_attribute("session_id", request.session_id)
span.set_attribute("agent_id", self.agent_id)
span.set_attribute("request", request.model_dump_json())
assert request.stream is True, "Non-streaming not supported"
session_info = await self.storage.get_session_info(request.session_id)
if session_info is None:
raise ValueError(f"Session {request.session_id} not found")
session_info = await self.storage.get_session_info(request.session_id)
if session_info is None:
raise ValueError(f"Session {request.session_id} not found")
turns = await self.storage.get_session_turns(request.session_id)
turns = await self.storage.get_session_turns(request.session_id)
messages = []
if self.agent_config.instructions != "":
messages.append(SystemMessage(content=self.agent_config.instructions))
messages = []
if self.agent_config.instructions != "":
messages.append(SystemMessage(content=self.agent_config.instructions))
for i, turn in enumerate(turns):
messages.extend(self.turn_to_messages(turn))
for i, turn in enumerate(turns):
messages.extend(self.turn_to_messages(turn))
messages.extend(request.messages)
messages.extend(request.messages)
turn_id = str(uuid.uuid4())
start_time = datetime.now()
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseTurnStartPayload(
turn_id=turn_id,
turn_id = str(uuid.uuid4())
span.set_attribute("turn_id", turn_id)
start_time = datetime.now()
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseTurnStartPayload(
turn_id=turn_id,
)
)
)
)
steps = []
output_message = None
async for chunk in self.run(
session_id=request.session_id,
turn_id=turn_id,
input_messages=messages,
attachments=request.attachments or [],
sampling_params=self.agent_config.sampling_params,
stream=request.stream,
):
if isinstance(chunk, CompletionMessage):
log.info(
f"{chunk.role.capitalize()}: {chunk.content}",
)
output_message = chunk
continue
assert isinstance(
chunk, AgentTurnResponseStreamChunk
), f"Unexpected type {type(chunk)}"
event = chunk.event
if (
event.payload.event_type
== AgentTurnResponseEventType.step_complete.value
steps = []
output_message = None
async for chunk in self.run(
session_id=request.session_id,
turn_id=turn_id,
input_messages=messages,
attachments=request.attachments or [],
sampling_params=self.agent_config.sampling_params,
stream=request.stream,
):
steps.append(event.payload.step_details)
if isinstance(chunk, CompletionMessage):
log.info(
f"{chunk.role.capitalize()}: {chunk.content}",
)
output_message = chunk
continue
yield chunk
assert isinstance(
chunk, AgentTurnResponseStreamChunk
), f"Unexpected type {type(chunk)}"
event = chunk.event
if (
event.payload.event_type
== AgentTurnResponseEventType.step_complete.value
):
steps.append(event.payload.step_details)
assert output_message is not None
yield chunk
turn = Turn(
turn_id=turn_id,
session_id=request.session_id,
input_messages=request.messages,
output_message=output_message,
started_at=start_time,
completed_at=datetime.now(),
steps=steps,
)
await self.storage.add_turn_to_session(request.session_id, turn)
assert output_message is not None
chunk = AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseTurnCompletePayload(
turn=turn,
turn = Turn(
turn_id=turn_id,
session_id=request.session_id,
input_messages=request.messages,
output_message=output_message,
started_at=start_time,
completed_at=datetime.now(),
steps=steps,
)
await self.storage.add_turn_to_session(request.session_id, turn)
chunk = AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseTurnCompletePayload(
turn=turn,
)
)
)
)
yield chunk
yield chunk
async def run(
self,
@ -273,7 +272,6 @@ class ChatAgent(ShieldRunnerMixin):
yield final_response
@tracing.span("run_shields")
async def run_multiple_shields_wrapper(
self,
turn_id: str,
@ -281,23 +279,46 @@ class ChatAgent(ShieldRunnerMixin):
shields: List[str],
touchpoint: str,
) -> AsyncGenerator:
if len(shields) == 0:
return
with tracing.span("run_shields") as span:
span.set_attribute("input", [m.model_dump_json() for m in messages])
if len(shields) == 0:
span.set_attribute("output", "no shields")
return
step_id = str(uuid.uuid4())
try:
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepStartPayload(
step_type=StepType.shield_call.value,
step_id=step_id,
metadata=dict(touchpoint=touchpoint),
step_id = str(uuid.uuid4())
try:
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepStartPayload(
step_type=StepType.shield_call.value,
step_id=step_id,
metadata=dict(touchpoint=touchpoint),
)
)
)
)
await self.run_multiple_shields(messages, shields)
await self.run_multiple_shields(messages, shields)
except SafetyException as e:
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepCompletePayload(
step_type=StepType.shield_call.value,
step_details=ShieldCallStep(
step_id=step_id,
turn_id=turn_id,
violation=e.violation,
),
)
)
)
span.set_attribute("output", e.violation.model_dump_json())
yield CompletionMessage(
content=str(e),
stop_reason=StopReason.end_of_turn,
)
yield False
except SafetyException as e:
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepCompletePayload(
@ -305,30 +326,12 @@ class ChatAgent(ShieldRunnerMixin):
step_details=ShieldCallStep(
step_id=step_id,
turn_id=turn_id,
violation=e.violation,
violation=None,
),
)
)
)
yield CompletionMessage(
content=str(e),
stop_reason=StopReason.end_of_turn,
)
yield False
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepCompletePayload(
step_type=StepType.shield_call.value,
step_details=ShieldCallStep(
step_id=step_id,
turn_id=turn_id,
violation=None,
),
)
)
)
span.set_attribute("output", "no violations")
async def _run(
self,
@ -356,10 +359,15 @@ class ChatAgent(ShieldRunnerMixin):
# TODO: find older context from the session and either replace it
# or append with a sliding window. this is really a very simplistic implementation
with tracing.span("retrieve_rag_context"):
with tracing.span("retrieve_rag_context") as span:
rag_context, bank_ids = await self._retrieve_context(
session_id, input_messages, attachments
)
span.set_attribute(
"input", [m.model_dump_json() for m in input_messages]
)
span.set_attribute("output", rag_context)
span.set_attribute("bank_ids", bank_ids)
step_id = str(uuid.uuid4())
yield AgentTurnResponseStreamChunk(
@ -396,11 +404,6 @@ class ChatAgent(ShieldRunnerMixin):
n_iter = 0
while True:
msg = input_messages[-1]
if len(str(msg)) > 1000:
msg_str = f"{str(msg)[:500]}...<more>...{str(msg)[-500:]}"
else:
msg_str = str(msg)
log.info(f"{msg_str}")
step_id = str(uuid.uuid4())
yield AgentTurnResponseStreamChunk(
@ -416,7 +419,7 @@ class ChatAgent(ShieldRunnerMixin):
content = ""
stop_reason = None
with tracing.span("inference"):
with tracing.span("inference") as span:
async for chunk in await self.inference_api.chat_completion(
self.agent_config.model,
input_messages,
@ -436,14 +439,13 @@ class ChatAgent(ShieldRunnerMixin):
if isinstance(delta, ToolCallDelta):
if delta.parse_status == ToolCallParseStatus.success:
tool_calls.append(delta.content)
if stream:
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepProgressPayload(
step_type=StepType.inference.value,
step_id=step_id,
model_response_text_delta="",
text_delta="",
tool_call_delta=delta,
)
)
@ -457,7 +459,7 @@ class ChatAgent(ShieldRunnerMixin):
payload=AgentTurnResponseStepProgressPayload(
step_type=StepType.inference.value,
step_id=step_id,
model_response_text_delta=event.delta,
text_delta=event.delta,
)
)
)
@ -466,6 +468,13 @@ class ChatAgent(ShieldRunnerMixin):
if event.stop_reason is not None:
stop_reason = event.stop_reason
span.set_attribute("stop_reason", stop_reason)
span.set_attribute(
"input", [m.model_dump_json() for m in input_messages]
)
span.set_attribute(
"output", f"content: {content} tool_calls: {tool_calls}"
)
stop_reason = stop_reason or StopReason.out_of_tokens
@ -549,7 +558,13 @@ class ChatAgent(ShieldRunnerMixin):
)
)
with tracing.span("tool_execution"):
with tracing.span(
"tool_execution",
{
"tool_name": tool_call.tool_name,
"input": message.model_dump_json(),
},
) as span:
result_messages = await execute_tool_call_maybe(
self.tools_dict,
[message],
@ -558,6 +573,7 @@ class ChatAgent(ShieldRunnerMixin):
len(result_messages) == 1
), "Currently not supporting multiple messages"
result_message = result_messages[0]
span.set_attribute("output", result_message.model_dump_json())
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(

View file

@ -6,9 +6,13 @@
import json
import logging
import shutil
import tempfile
import uuid
from typing import AsyncGenerator
from termcolor import colored
from llama_stack.apis.inference import Inference
from llama_stack.apis.memory import Memory
from llama_stack.apis.memory_banks import MemoryBanks
@ -40,10 +44,20 @@ class MetaReferenceAgentsImpl(Agents):
self.memory_banks_api = memory_banks_api
self.in_memory_store = InmemoryKVStoreImpl()
self.tempdir = tempfile.mkdtemp()
async def initialize(self) -> None:
self.persistence_store = await kvstore_impl(self.config.persistence_store)
# check if "bwrap" is available
if not shutil.which("bwrap"):
print(
colored(
"Warning: `bwrap` is not available. Code interpreter tool will not work correctly.",
"yellow",
)
)
async def create_agent(
self,
agent_config: AgentConfig,
@ -82,6 +96,7 @@ class MetaReferenceAgentsImpl(Agents):
return ChatAgent(
agent_id=agent_id,
agent_config=agent_config,
tempdir=self.tempdir,
inference_api=self.inference_api,
safety_api=self.safety_api,
memory_api=self.memory_api,

View file

@ -3,14 +3,17 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Optional
from typing import Any, Dict, List, Optional
import pandas
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.datasetio import * # noqa: F403
import base64
import os
from abc import ABC, abstractmethod
from dataclasses import dataclass
from urllib.parse import urlparse
from llama_stack.providers.datatypes import DatasetsProtocolPrivate
from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url
@ -97,6 +100,9 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
dataset_impl=dataset_impl,
)
async def unregister_dataset(self, dataset_id: str) -> None:
del self.dataset_infos[dataset_id]
async def get_rows_paginated(
self,
dataset_id: str,
@ -128,3 +134,41 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
total_count=len(rows),
next_page_token=str(end),
)
async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
dataset_info = self.dataset_infos.get(dataset_id)
if dataset_info is None:
raise ValueError(f"Dataset with id {dataset_id} not found")
dataset_impl = dataset_info.dataset_impl
dataset_impl.load()
new_rows_df = pandas.DataFrame(rows)
new_rows_df = dataset_impl._validate_dataset_schema(new_rows_df)
dataset_impl.df = pandas.concat(
[dataset_impl.df, new_rows_df], ignore_index=True
)
url = str(dataset_info.dataset_def.url)
parsed_url = urlparse(url)
if parsed_url.scheme == "file" or not parsed_url.scheme:
file_path = parsed_url.path
os.makedirs(os.path.dirname(file_path), exist_ok=True)
dataset_impl.df.to_csv(file_path, index=False)
elif parsed_url.scheme == "data":
# For data URLs, we need to update the base64-encoded content
if not parsed_url.path.startswith("text/csv;base64,"):
raise ValueError("Data URL must be a base64-encoded CSV")
csv_buffer = dataset_impl.df.to_csv(index=False)
base64_content = base64.b64encode(csv_buffer.encode("utf-8")).decode(
"utf-8"
)
dataset_info.dataset_def.url = URL(
uri=f"data:text/csv;base64,{base64_content}"
)
else:
raise ValueError(
f"Unsupported URL scheme: {parsed_url.scheme}. Only file:// and data: URLs are supported for writing."
)

View file

@ -3,12 +3,13 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pydantic import BaseModel
from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR
from llama_stack.providers.utils.kvstore.config import (
KVStoreConfig,
SqliteKVStoreConfig,
)
from pydantic import BaseModel
class MetaReferenceEvalConfig(BaseModel):

View file

@ -4,7 +4,9 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from typing import Any, Dict, List, Optional
from llama_models.llama3.api.datatypes import * # noqa: F403
from tqdm import tqdm
from .....apis.common.job_types import Job
from .....apis.eval.eval import Eval, EvalTaskConfig, EvaluateResponse, JobStatus
@ -17,7 +19,6 @@ from llama_stack.apis.inference import Inference
from llama_stack.apis.scoring import Scoring
from llama_stack.providers.datatypes import EvalTasksProtocolPrivate
from llama_stack.providers.utils.kvstore import kvstore_impl
from tqdm import tqdm
from .config import MetaReferenceEvalConfig

View file

@ -27,7 +27,6 @@ from llama_stack.providers.utils.memory.vector_store import (
BankWithIndex,
EmbeddingIndex,
)
from llama_stack.providers.utils.telemetry import tracing
from .config import FaissImplConfig
@ -95,7 +94,6 @@ class FaissIndex(EmbeddingIndex):
await self.kvstore.delete(f"faiss_index:v1::{self.bank_id}")
@tracing.span(name="add_chunks")
async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
indexlen = len(self.id_by_index)
for i, chunk in enumerate(chunks):

View file

@ -1,15 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .config import ConsoleConfig
async def get_provider_impl(config: ConsoleConfig, _deps):
from .console import ConsoleTelemetryImpl
impl = ConsoleTelemetryImpl(config)
await impl.initialize()
return impl

View file

@ -1,21 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel
class LogFormat(Enum):
TEXT = "text"
JSON = "json"
@json_schema_type
class ConsoleConfig(BaseModel):
log_format: LogFormat = LogFormat.TEXT

View file

@ -5,7 +5,7 @@
# the root directory of this source tree.
import json
from typing import Optional
from typing import List, Optional
from .config import LogFormat
@ -49,8 +49,27 @@ class ConsoleTelemetryImpl(Telemetry):
if formatted:
print(formatted)
async def get_trace(self, trace_id: str) -> Trace:
raise NotImplementedError()
async def query_traces(
self,
attribute_conditions: Optional[List[QueryCondition]] = None,
attribute_keys_to_return: Optional[List[str]] = None,
limit: Optional[int] = 100,
offset: Optional[int] = 0,
order_by: Optional[List[str]] = None,
) -> List[Trace]:
raise NotImplementedError("Console telemetry does not support trace querying")
async def get_spans(
self,
span_id: str,
attribute_conditions: Optional[List[QueryCondition]] = None,
attribute_keys_to_return: Optional[List[str]] = None,
max_depth: Optional[int] = None,
limit: Optional[int] = 100,
offset: Optional[int] = 0,
order_by: Optional[List[str]] = None,
) -> SpanWithChildren:
raise NotImplementedError("Console telemetry does not support span querying")
COLORS = {

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -113,7 +113,9 @@ class BasicScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
score_results = await scoring_fn.score(
input_rows, scoring_fn_id, scoring_fn_params
)
agg_results = await scoring_fn.aggregate(score_results)
agg_results = await scoring_fn.aggregate(
score_results, scoring_fn_id, scoring_fn_params
)
res[scoring_fn_id] = ScoringResult(
score_rows=score_results,
aggregated_results=agg_results,

View file

@ -4,12 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn
from llama_stack.apis.scoring_functions import * # noqa: F401, F403
from llama_stack.apis.scoring import * # noqa: F401, F403
from llama_stack.apis.common.type_system import * # noqa: F403
from typing import Any, Dict, Optional
from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_accuracy
from llama_stack.apis.scoring import ScoringResultRow
from llama_stack.apis.scoring_functions import ScoringFnParams
from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn
from .fn_defs.equality import equality
@ -42,8 +42,3 @@ class EqualityScoringFn(BaseScoringFn):
return {
"score": score,
}
async def aggregate(
self, scoring_results: List[ScoringResultRow]
) -> Dict[str, Any]:
return aggregate_accuracy(scoring_results)

View file

@ -5,14 +5,20 @@
# the root directory of this source tree.
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import ScoringFn
from llama_stack.apis.scoring_functions import (
AggregationFunctionType,
BasicScoringFnParams,
ScoringFn,
)
equality = ScoringFn(
identifier="basic::equality",
description="Returns 1.0 if the input is equal to the target, 0.0 otherwise.",
params=None,
provider_id="basic",
provider_resource_id="equality",
return_type=NumberType(),
params=BasicScoringFnParams(
aggregation_functions=[AggregationFunctionType.accuracy]
),
)

View file

@ -4,9 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.scoring_functions import * # noqa: F401, F403
from llama_stack.apis.scoring import * # noqa: F401, F403
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import (
AggregationFunctionType,
RegexParserScoringFnParams,
ScoringFn,
)
MULTILINGUAL_ANSWER_REGEXES = [
r"Answer\s*:",
@ -67,5 +70,6 @@ regex_parser_multiple_choice_answer = ScoringFn(
MULTILINGUAL_ANSWER_PATTERN_TEMPLATE.format(x)
for x in MULTILINGUAL_ANSWER_REGEXES
],
aggregation_functions=[AggregationFunctionType.accuracy],
),
)

View file

@ -5,7 +5,11 @@
# the root directory of this source tree.
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import ScoringFn
from llama_stack.apis.scoring_functions import (
AggregationFunctionType,
BasicScoringFnParams,
ScoringFn,
)
subset_of = ScoringFn(
@ -14,4 +18,7 @@ subset_of = ScoringFn(
return_type=NumberType(),
provider_id="basic",
provider_resource_id="subset-of",
params=BasicScoringFnParams(
aggregation_functions=[AggregationFunctionType.accuracy]
),
)

View file

@ -5,11 +5,11 @@
# the root directory of this source tree.
import re
from typing import Any, Dict, Optional
from llama_stack.apis.scoring import ScoringResultRow
from llama_stack.apis.scoring_functions import ScoringFnParams, ScoringFnParamsType
from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn
from llama_stack.apis.scoring_functions import * # noqa: F401, F403
from llama_stack.apis.scoring import * # noqa: F401, F403
from llama_stack.apis.common.type_system import * # noqa: F403
from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_accuracy
from .fn_defs.regex_parser_multiple_choice_answer import (
regex_parser_multiple_choice_answer,
@ -60,8 +60,3 @@ class RegexParserScoringFn(BaseScoringFn):
return {
"score": score,
}
async def aggregate(
self, scoring_results: List[ScoringResultRow]
) -> Dict[str, Any]:
return aggregate_accuracy(scoring_results)

View file

@ -4,11 +4,11 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, Optional
from llama_stack.apis.scoring import ScoringResultRow
from llama_stack.apis.scoring_functions import ScoringFnParams
from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn
from llama_stack.apis.scoring_functions import * # noqa: F401, F403
from llama_stack.apis.scoring import * # noqa: F401, F403
from llama_stack.apis.common.type_system import * # noqa: F403
from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_accuracy
from .fn_defs.subset_of import subset_of
@ -36,8 +36,3 @@ class SubsetOfScoringFn(BaseScoringFn):
return {
"score": score,
}
async def aggregate(
self, scoring_results: List[ScoringResultRow]
) -> Dict[str, Any]:
return aggregate_accuracy(scoring_results)

View file

@ -5,11 +5,17 @@
# the root directory of this source tree.
from typing import Dict
from pydantic import BaseModel
from llama_stack.distribution.datatypes import Api, ProviderSpec
from .config import BraintrustScoringConfig
class BraintrustProviderDataValidator(BaseModel):
openai_api_key: str
async def get_provider_impl(
config: BraintrustScoringConfig,
deps: Dict[Api, ProviderSpec],

View file

@ -12,9 +12,12 @@ from llama_stack.apis.common.type_system import * # noqa: F403
from llama_stack.apis.datasetio import * # noqa: F403
from llama_stack.apis.datasets import * # noqa: F403
# from .scoring_fn.braintrust_scoring_fn import BraintrustScoringFn
import os
from autoevals.llm import Factuality
from autoevals.ragas import AnswerCorrectness
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_average
@ -24,7 +27,9 @@ from .scoring_fn.fn_defs.answer_correctness import answer_correctness_fn_def
from .scoring_fn.fn_defs.factuality import factuality_fn_def
class BraintrustScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
class BraintrustScoringImpl(
Scoring, ScoringFunctionsProtocolPrivate, NeedsRequestProviderData
):
def __init__(
self,
config: BraintrustScoringConfig,
@ -79,12 +84,25 @@ class BraintrustScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
f"Dataset {dataset_id} does not have a '{required_column}' column of type 'string'."
)
async def set_api_key(self) -> None:
# api key is in the request headers
if not self.config.openai_api_key:
provider_data = self.get_request_provider_data()
if provider_data is None or not provider_data.openai_api_key:
raise ValueError(
'Pass OpenAI API Key in the header X-LlamaStack-ProviderData as { "openai_api_key": <your api key>}'
)
self.config.openai_api_key = provider_data.openai_api_key
os.environ["OPENAI_API_KEY"] = self.config.openai_api_key
async def score_batch(
self,
dataset_id: str,
scoring_functions: List[str],
save_results_dataset: bool = False,
) -> ScoreBatchResponse:
await self.set_api_key()
await self.validate_scoring_input_dataset_schema(dataset_id=dataset_id)
all_rows = await self.datasetio_api.get_rows_paginated(
dataset_id=dataset_id,
@ -105,6 +123,7 @@ class BraintrustScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
async def score_row(
self, input_row: Dict[str, Any], scoring_fn_identifier: Optional[str] = None
) -> ScoringResultRow:
await self.set_api_key()
assert scoring_fn_identifier is not None, "scoring_fn_identifier cannot be None"
expected_answer = input_row["expected_answer"]
generated_answer = input_row["generated_answer"]
@ -118,6 +137,7 @@ class BraintrustScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
async def score(
self, input_rows: List[Dict[str, Any]], scoring_functions: List[str]
) -> ScoreResponse:
await self.set_api_key()
res = {}
for scoring_fn_id in scoring_functions:
if scoring_fn_id not in self.supported_fn_defs_registry:
@ -127,7 +147,7 @@ class BraintrustScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
await self.score_row(input_row, scoring_fn_id)
for input_row in input_rows
]
aggregation_functions = [AggregationFunctionType.average]
agg_results = aggregate_average(score_results)
res[scoring_fn_id] = ScoringResult(
score_rows=score_results,

View file

@ -6,4 +6,14 @@
from llama_stack.apis.scoring import * # noqa: F401, F403
class BraintrustScoringConfig(BaseModel): ...
class BraintrustScoringConfig(BaseModel):
openai_api_key: Optional[str] = Field(
default=None,
description="The OpenAI API Key",
)
@classmethod
def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
return {
"openai_api_key": "${env.OPENAI_API_KEY:}",
}

View file

@ -120,7 +120,9 @@ class LlmAsJudgeScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
score_results = await scoring_fn.score(
input_rows, scoring_fn_id, scoring_fn_params
)
agg_results = await scoring_fn.aggregate(score_results)
agg_results = await scoring_fn.aggregate(
score_results, scoring_fn_id, scoring_fn_params
)
res[scoring_fn_id] = ScoringResult(
score_rows=score_results,
aggregated_results=agg_results,

View file

@ -5,7 +5,7 @@
# the root directory of this source tree.
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import ScoringFn
from llama_stack.apis.scoring_functions import LLMAsJudgeScoringFnParams, ScoringFn
llm_as_judge_base = ScoringFn(
@ -14,4 +14,8 @@ llm_as_judge_base = ScoringFn(
return_type=NumberType(),
provider_id="llm-as-judge",
provider_resource_id="llm-as-judge-base",
params=LLMAsJudgeScoringFnParams(
judge_model="meta-llama/Llama-3.1-405B-Instruct",
prompt_template="Enter custom LLM as Judge Prompt Template",
),
)

View file

@ -3,13 +3,16 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import re
from typing import Any, Dict, Optional
from llama_stack.apis.inference.inference import Inference
from llama_stack.apis.scoring import ScoringResultRow
from llama_stack.apis.scoring_functions import ScoringFnParams
from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn
from llama_stack.apis.scoring_functions import * # noqa: F401, F403
from llama_stack.apis.scoring import * # noqa: F401, F403
from llama_stack.apis.common.type_system import * # noqa: F403
import re
from .fn_defs.llm_as_judge_405b_simpleqa import llm_as_judge_405b_simpleqa
@ -85,9 +88,3 @@ class LlmAsJudgeScoringFn(BaseScoringFn):
"score": judge_rating,
"judge_feedback": content,
}
async def aggregate(
self, scoring_results: List[ScoringResultRow]
) -> Dict[str, Any]:
# TODO: this needs to be config based aggregation, and only useful w/ Jobs API
return {}

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,19 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from .config import TelemetryConfig, TelemetrySink
__all__ = ["TelemetryConfig", "TelemetrySink"]
async def get_provider_impl(config: TelemetryConfig, deps: Dict[str, Any]):
from .telemetry import TelemetryAdapter
impl = TelemetryAdapter(config, deps)
await impl.initialize()
return impl

View file

@ -0,0 +1,58 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from typing import Any, Dict, List
from pydantic import BaseModel, Field, field_validator
from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR
class TelemetrySink(str, Enum):
OTEL = "otel"
SQLITE = "sqlite"
CONSOLE = "console"
class TelemetryConfig(BaseModel):
otel_endpoint: str = Field(
default="http://localhost:4318/v1/traces",
description="The OpenTelemetry collector endpoint URL",
)
service_name: str = Field(
default="llama-stack",
description="The service name to use for telemetry",
)
sinks: List[TelemetrySink] = Field(
default=[TelemetrySink.CONSOLE, TelemetrySink.SQLITE],
description="List of telemetry sinks to enable (possible values: otel, sqlite, console)",
)
sqlite_db_path: str = Field(
default=(RUNTIME_BASE_DIR / "trace_store.db").as_posix(),
description="The path to the SQLite database to use for storing traces",
)
@field_validator("sinks", mode="before")
@classmethod
def validate_sinks(cls, v):
if isinstance(v, str):
return [TelemetrySink(sink.strip()) for sink in v.split(",")]
return v
@classmethod
def sample_run_config(
cls, __distro_dir__: str = "runtime", db_name: str = "trace_store.db"
) -> Dict[str, Any]:
return {
"service_name": "${env.OTEL_SERVICE_NAME:llama-stack}",
"sinks": "${env.TELEMETRY_SINKS:console,sqlite}",
"sqlite_db_path": "${env.SQLITE_DB_PATH:~/.llama/"
+ __distro_dir__
+ "/"
+ db_name
+ "}",
}

View file

@ -0,0 +1,117 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
from datetime import datetime
from opentelemetry.sdk.trace import ReadableSpan
from opentelemetry.sdk.trace.export import SpanProcessor
from opentelemetry.trace.status import StatusCode
# Colors for console output
COLORS = {
"reset": "\033[0m",
"bold": "\033[1m",
"dim": "\033[2m",
"red": "\033[31m",
"green": "\033[32m",
"yellow": "\033[33m",
"blue": "\033[34m",
"magenta": "\033[35m",
"cyan": "\033[36m",
"white": "\033[37m",
}
class ConsoleSpanProcessor(SpanProcessor):
def __init__(self, print_attributes: bool = False):
self.print_attributes = print_attributes
def on_start(self, span: ReadableSpan, parent_context=None) -> None:
if span.attributes and span.attributes.get("__autotraced__"):
return
timestamp = datetime.utcfromtimestamp(span.start_time / 1e9).strftime(
"%H:%M:%S.%f"
)[:-3]
print(
f"{COLORS['dim']}{timestamp}{COLORS['reset']} "
f"{COLORS['magenta']}[START]{COLORS['reset']} "
f"{COLORS['dim']}{span.name}{COLORS['reset']}"
)
def on_end(self, span: ReadableSpan) -> None:
if span.attributes and span.attributes.get("__autotraced__"):
return
timestamp = datetime.utcfromtimestamp(span.end_time / 1e9).strftime(
"%H:%M:%S.%f"
)[:-3]
span_context = (
f"{COLORS['dim']}{timestamp}{COLORS['reset']} "
f"{COLORS['magenta']}[END]{COLORS['reset']} "
f"{COLORS['dim']}{span.name}{COLORS['reset']}"
)
if span.status.status_code == StatusCode.ERROR:
span_context += f"{COLORS['reset']} {COLORS['red']}[ERROR]{COLORS['reset']}"
elif span.status.status_code != StatusCode.UNSET:
span_context += f"{COLORS['reset']} [{span.status.status_code}]"
duration_ms = (span.end_time - span.start_time) / 1e6
span_context += f"{COLORS['reset']} ({duration_ms:.2f}ms)"
print(span_context)
if self.print_attributes and span.attributes:
for key, value in span.attributes.items():
if key.startswith("__"):
continue
str_value = str(value)
if len(str_value) > 1000:
str_value = str_value[:997] + "..."
print(f" {COLORS['dim']}{key}: {str_value}{COLORS['reset']}")
for event in span.events:
event_time = datetime.utcfromtimestamp(event.timestamp / 1e9).strftime(
"%H:%M:%S.%f"
)[:-3]
severity = event.attributes.get("severity", "info")
message = event.attributes.get("message", event.name)
if isinstance(message, (dict, list)):
message = json.dumps(message, indent=2)
severity_colors = {
"error": f"{COLORS['bold']}{COLORS['red']}",
"warn": f"{COLORS['bold']}{COLORS['yellow']}",
"info": COLORS["white"],
"debug": COLORS["dim"],
}
msg_color = severity_colors.get(severity, COLORS["white"])
print(
f" {event_time} "
f"{msg_color}[{severity.upper()}] "
f"{message}{COLORS['reset']}"
)
if event.attributes:
for key, value in event.attributes.items():
if key.startswith("__") or key in ["message", "severity"]:
continue
print(f" {COLORS['dim']}{key}: {value}{COLORS['reset']}")
def shutdown(self) -> None:
"""Shutdown the processor."""
pass
def force_flush(self, timeout_millis: float = None) -> bool:
"""Force flush any pending spans."""
return True

View file

@ -0,0 +1,177 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import os
import sqlite3
from datetime import datetime
from opentelemetry.sdk.trace import SpanProcessor
from opentelemetry.trace import Span
class SQLiteSpanProcessor(SpanProcessor):
def __init__(self, conn_string):
"""Initialize the SQLite span processor with a connection string."""
self.conn_string = conn_string
self.conn = None
self.setup_database()
def _get_connection(self) -> sqlite3.Connection:
"""Get the database connection."""
if self.conn is None:
self.conn = sqlite3.connect(self.conn_string, check_same_thread=False)
return self.conn
def setup_database(self):
"""Create the necessary tables if they don't exist."""
# Create directory if it doesn't exist
os.makedirs(os.path.dirname(self.conn_string), exist_ok=True)
conn = self._get_connection()
cursor = conn.cursor()
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS traces (
trace_id TEXT PRIMARY KEY,
service_name TEXT,
root_span_id TEXT,
start_time TIMESTAMP,
end_time TIMESTAMP,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
"""
)
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS spans (
span_id TEXT PRIMARY KEY,
trace_id TEXT REFERENCES traces(trace_id),
parent_span_id TEXT,
name TEXT,
start_time TIMESTAMP,
end_time TIMESTAMP,
attributes TEXT,
status TEXT,
kind TEXT
)
"""
)
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS span_events (
id INTEGER PRIMARY KEY AUTOINCREMENT,
span_id TEXT REFERENCES spans(span_id),
name TEXT,
timestamp TIMESTAMP,
attributes TEXT
)
"""
)
cursor.execute(
"""
CREATE INDEX IF NOT EXISTS idx_traces_created_at
ON traces(created_at)
"""
)
conn.commit()
cursor.close()
def on_start(self, span: Span, parent_context=None):
"""Called when a span starts."""
pass
def on_end(self, span: Span):
"""Called when a span ends. Export the span data to SQLite."""
try:
conn = self._get_connection()
cursor = conn.cursor()
trace_id = format(span.get_span_context().trace_id, "032x")
span_id = format(span.get_span_context().span_id, "016x")
service_name = span.resource.attributes.get("service.name", "unknown")
parent_span_id = None
parent_context = span.parent
if parent_context:
parent_span_id = format(parent_context.span_id, "016x")
# Insert into traces
cursor.execute(
"""
INSERT INTO traces (
trace_id, service_name, root_span_id, start_time, end_time
) VALUES (?, ?, ?, ?, ?)
ON CONFLICT(trace_id) DO UPDATE SET
root_span_id = COALESCE(root_span_id, excluded.root_span_id),
start_time = MIN(excluded.start_time, start_time),
end_time = MAX(excluded.end_time, end_time)
""",
(
trace_id,
service_name,
(span_id if not parent_span_id else None),
datetime.fromtimestamp(span.start_time / 1e9).isoformat(),
datetime.fromtimestamp(span.end_time / 1e9).isoformat(),
),
)
# Insert into spans
cursor.execute(
"""
INSERT INTO spans (
span_id, trace_id, parent_span_id, name,
start_time, end_time, attributes, status,
kind
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
(
span_id,
trace_id,
parent_span_id,
span.name,
datetime.fromtimestamp(span.start_time / 1e9).isoformat(),
datetime.fromtimestamp(span.end_time / 1e9).isoformat(),
json.dumps(dict(span.attributes)),
span.status.status_code.name,
span.kind.name,
),
)
for event in span.events:
cursor.execute(
"""
INSERT INTO span_events (
span_id, name, timestamp, attributes
) VALUES (?, ?, ?, ?)
""",
(
span_id,
event.name,
datetime.fromtimestamp(event.timestamp / 1e9).isoformat(),
json.dumps(dict(event.attributes)),
),
)
conn.commit()
cursor.close()
except Exception as e:
print(f"Error exporting span to SQLite: {e}")
def shutdown(self):
"""Cleanup any resources."""
if self.conn:
self.conn.close()
self.conn = None
def force_flush(self, timeout_millis=30000):
"""Force export of spans."""
pass

View file

@ -5,6 +5,7 @@
# the root directory of this source tree.
import threading
from typing import Any, Dict, List, Optional
from opentelemetry import metrics, trace
from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter
@ -16,10 +17,21 @@ from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor
from opentelemetry.semconv.resource import ResourceAttributes
from llama_stack.providers.inline.telemetry.meta_reference.console_span_processor import (
ConsoleSpanProcessor,
)
from llama_stack.providers.inline.telemetry.meta_reference.sqlite_span_processor import (
SQLiteSpanProcessor,
)
from llama_stack.providers.utils.telemetry.dataset_mixin import TelemetryDatasetMixin
from llama_stack.providers.utils.telemetry.sqlite_trace_store import SQLiteTraceStore
from llama_stack.apis.telemetry import * # noqa: F403
from .config import OpenTelemetryConfig
from llama_stack.distribution.datatypes import Api
from .config import TelemetryConfig, TelemetrySink
_GLOBAL_STORAGE = {
"active_spans": {},
@ -45,9 +57,10 @@ def is_tracing_enabled(tracer):
return span.is_recording()
class OpenTelemetryAdapter(Telemetry):
def __init__(self, config: OpenTelemetryConfig):
class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
def __init__(self, config: TelemetryConfig, deps: Dict[str, Any]) -> None:
self.config = config
self.datasetio_api = deps[Api.datasetio]
resource = Resource.create(
{
@ -57,22 +70,29 @@ class OpenTelemetryAdapter(Telemetry):
provider = TracerProvider(resource=resource)
trace.set_tracer_provider(provider)
otlp_exporter = OTLPSpanExporter(
endpoint=self.config.otel_endpoint,
)
span_processor = BatchSpanProcessor(otlp_exporter)
trace.get_tracer_provider().add_span_processor(span_processor)
# Set up metrics
metric_reader = PeriodicExportingMetricReader(
OTLPMetricExporter(
if TelemetrySink.OTEL in self.config.sinks:
otlp_exporter = OTLPSpanExporter(
endpoint=self.config.otel_endpoint,
)
)
metric_provider = MeterProvider(
resource=resource, metric_readers=[metric_reader]
)
metrics.set_meter_provider(metric_provider)
self.meter = metrics.get_meter(__name__)
span_processor = BatchSpanProcessor(otlp_exporter)
trace.get_tracer_provider().add_span_processor(span_processor)
metric_reader = PeriodicExportingMetricReader(
OTLPMetricExporter(
endpoint=self.config.otel_endpoint,
)
)
metric_provider = MeterProvider(
resource=resource, metric_readers=[metric_reader]
)
metrics.set_meter_provider(metric_provider)
self.meter = metrics.get_meter(__name__)
if TelemetrySink.SQLITE in self.config.sinks:
trace.get_tracer_provider().add_span_processor(
SQLiteSpanProcessor(self.config.sqlite_db_path)
)
self.trace_store = SQLiteTraceStore(self.config.sqlite_db_path)
if TelemetrySink.CONSOLE in self.config.sinks:
trace.get_tracer_provider().add_span_processor(ConsoleSpanProcessor())
self._lock = _global_lock
async def initialize(self) -> None:
@ -83,15 +103,17 @@ class OpenTelemetryAdapter(Telemetry):
trace.get_tracer_provider().shutdown()
metrics.get_meter_provider().shutdown()
async def log_event(self, event: Event) -> None:
async def log_event(self, event: Event, ttl_seconds: int = 604800) -> None:
if isinstance(event, UnstructuredLogEvent):
self._log_unstructured(event)
self._log_unstructured(event, ttl_seconds)
elif isinstance(event, MetricEvent):
self._log_metric(event)
elif isinstance(event, StructuredLogEvent):
self._log_structured(event)
self._log_structured(event, ttl_seconds)
else:
raise ValueError(f"Unknown event type: {event}")
def _log_unstructured(self, event: UnstructuredLogEvent) -> None:
def _log_unstructured(self, event: UnstructuredLogEvent, ttl_seconds: int) -> None:
with self._lock:
# Use global storage instead of instance storage
span_id = string_to_span_id(event.span_id)
@ -104,6 +126,7 @@ class OpenTelemetryAdapter(Telemetry):
attributes={
"message": event.message,
"severity": event.severity.value,
"__ttl__": ttl_seconds,
**event.attributes,
},
timestamp=timestamp_ns,
@ -154,11 +177,14 @@ class OpenTelemetryAdapter(Telemetry):
)
return _GLOBAL_STORAGE["up_down_counters"][name]
def _log_structured(self, event: StructuredLogEvent) -> None:
def _log_structured(self, event: StructuredLogEvent, ttl_seconds: int) -> None:
with self._lock:
span_id = string_to_span_id(event.span_id)
trace_id = string_to_trace_id(event.trace_id)
tracer = trace.get_tracer(__name__)
if event.attributes is None:
event.attributes = {}
event.attributes["__ttl__"] = ttl_seconds
if isinstance(event.payload, SpanStartPayload):
# Check if span already exists to prevent duplicates
@ -170,7 +196,6 @@ class OpenTelemetryAdapter(Telemetry):
parent_span_id = string_to_span_id(event.payload.parent_span_id)
parent_span = _GLOBAL_STORAGE["active_spans"].get(parent_span_id)
# Create a new trace context with the trace_id
context = trace.Context(trace_id=trace_id)
if parent_span:
context = trace.set_span_in_context(parent_span, context)
@ -179,14 +204,9 @@ class OpenTelemetryAdapter(Telemetry):
name=event.payload.name,
context=context,
attributes=event.attributes or {},
start_time=int(event.timestamp.timestamp() * 1e9),
)
_GLOBAL_STORAGE["active_spans"][span_id] = span
# Set as current span using context manager
with trace.use_span(span, end_on_exit=False):
pass # Let the span continue beyond this block
elif isinstance(event.payload, SpanEndPayload):
span = _GLOBAL_STORAGE["active_spans"].get(span_id)
if span:
@ -199,10 +219,33 @@ class OpenTelemetryAdapter(Telemetry):
else trace.Status(status_code=trace.StatusCode.ERROR)
)
span.set_status(status)
span.end(end_time=int(event.timestamp.timestamp() * 1e9))
# Remove from active spans
span.end()
_GLOBAL_STORAGE["active_spans"].pop(span_id, None)
else:
raise ValueError(f"Unknown structured log event: {event}")
async def get_trace(self, trace_id: str) -> Trace:
raise NotImplementedError("Trace retrieval not implemented yet")
async def query_traces(
self,
attribute_filters: Optional[List[QueryCondition]] = None,
limit: Optional[int] = 100,
offset: Optional[int] = 0,
order_by: Optional[List[str]] = None,
) -> List[Trace]:
return await self.trace_store.query_traces(
attribute_filters=attribute_filters,
limit=limit,
offset=offset,
order_by=order_by,
)
async def get_span_tree(
self,
span_id: str,
attributes_to_return: Optional[List[str]] = None,
max_depth: Optional[int] = None,
) -> SpanWithChildren:
return await self.trace_store.get_span_tree(
span_id=span_id,
attributes_to_return=attributes_to_return,
max_depth=max_depth,
)

View file

@ -61,6 +61,17 @@ def available_providers() -> List[ProviderSpec]:
config_class="llama_stack.providers.remote.inference.sample.SampleConfig",
),
),
remote_provider_spec(
api=Api.inference,
adapter=AdapterSpec(
adapter_type="cerebras",
pip_packages=[
"cerebras_cloud_sdk",
],
module="llama_stack.providers.remote.inference.cerebras",
config_class="llama_stack.providers.remote.inference.cerebras.CerebrasImplConfig",
),
),
remote_provider_spec(
api=Api.inference,
adapter=AdapterSpec(

View file

@ -44,5 +44,6 @@ def available_providers() -> List[ProviderSpec]:
Api.datasetio,
Api.datasets,
],
provider_data_validator="llama_stack.providers.inline.scoring.braintrust.BraintrustProviderDataValidator",
),
]

View file

@ -14,9 +14,13 @@ def available_providers() -> List[ProviderSpec]:
InlineProviderSpec(
api=Api.telemetry,
provider_type="inline::meta-reference",
pip_packages=[],
module="llama_stack.providers.inline.meta_reference.telemetry",
config_class="llama_stack.providers.inline.meta_reference.telemetry.ConsoleConfig",
pip_packages=[
"opentelemetry-sdk",
"opentelemetry-exporter-otlp-proto-http",
],
api_dependencies=[Api.datasetio],
module="llama_stack.providers.inline.telemetry.meta_reference",
config_class="llama_stack.providers.inline.telemetry.meta_reference.config.TelemetryConfig",
),
remote_provider_spec(
api=Api.telemetry,
@ -27,18 +31,4 @@ def available_providers() -> List[ProviderSpec]:
config_class="llama_stack.providers.remote.telemetry.sample.SampleConfig",
),
),
remote_provider_spec(
api=Api.telemetry,
adapter=AdapterSpec(
adapter_type="opentelemetry-jaeger",
pip_packages=[
"opentelemetry-api",
"opentelemetry-sdk",
"opentelemetry-exporter-jaeger",
"opentelemetry-semantic-conventions",
],
module="llama_stack.providers.remote.telemetry.opentelemetry",
config_class="llama_stack.providers.remote.telemetry.opentelemetry.OpenTelemetryConfig",
),
),
]

View file

@ -3,7 +3,7 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Optional
from typing import Any, Dict, List, Optional
from llama_stack.apis.datasetio import * # noqa: F403
@ -64,6 +64,11 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
)
self.dataset_infos[dataset_def.identifier] = dataset_def
async def unregister_dataset(self, dataset_id: str) -> None:
key = f"{DATASETS_PREFIX}{dataset_id}"
await self.kvstore.delete(key=key)
del self.dataset_infos[dataset_id]
async def get_rows_paginated(
self,
dataset_id: str,
@ -95,3 +100,22 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
total_count=len(rows),
next_page_token=str(end),
)
async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
dataset_def = self.dataset_infos[dataset_id]
loaded_dataset = load_hf_dataset(dataset_def)
# Convert rows to HF Dataset format
new_dataset = hf_datasets.Dataset.from_list(rows)
# Concatenate the new rows with existing dataset
updated_dataset = hf_datasets.concatenate_datasets(
[loaded_dataset, new_dataset]
)
if dataset_def.metadata.get("path", None):
updated_dataset.push_to_hub(dataset_def.metadata["path"])
else:
raise NotImplementedError(
"Uploading to URL-based datasets is not supported yet"
)

View file

@ -0,0 +1,21 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .config import CerebrasImplConfig
async def get_adapter_impl(config: CerebrasImplConfig, _deps):
from .cerebras import CerebrasInferenceAdapter
assert isinstance(
config, CerebrasImplConfig
), f"Unexpected config type: {type(config)}"
impl = CerebrasInferenceAdapter(config)
await impl.initialize()
return impl

View file

@ -0,0 +1,191 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import AsyncGenerator
from cerebras.cloud.sdk import AsyncCerebras
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import Message
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_stack.apis.inference import * # noqa: F403
from llama_models.datatypes import CoreModelId
from llama_stack.providers.utils.inference.model_registry import (
build_model_alias,
ModelRegistryHelper,
)
from llama_stack.providers.utils.inference.openai_compat import (
get_sampling_options,
process_chat_completion_response,
process_chat_completion_stream_response,
process_completion_response,
process_completion_stream_response,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_prompt,
completion_request_to_prompt,
)
from .config import CerebrasImplConfig
model_aliases = [
build_model_alias(
"llama3.1-8b",
CoreModelId.llama3_1_8b_instruct.value,
),
build_model_alias(
"llama3.1-70b",
CoreModelId.llama3_1_70b_instruct.value,
),
]
class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
def __init__(self, config: CerebrasImplConfig) -> None:
ModelRegistryHelper.__init__(
self,
model_aliases=model_aliases,
)
self.config = config
self.formatter = ChatFormat(Tokenizer.get_instance())
self.client = AsyncCerebras(
base_url=self.config.base_url, api_key=self.config.api_key
)
async def initialize(self) -> None:
return
async def shutdown(self) -> None:
pass
async def completion(
self,
model_id: str,
content: InterleavedTextMedia,
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
model = await self.model_store.get_model(model_id)
request = CompletionRequest(
model=model.provider_resource_id,
content=content,
sampling_params=sampling_params,
response_format=response_format,
stream=stream,
logprobs=logprobs,
)
if stream:
return self._stream_completion(
request,
)
else:
return await self._nonstream_completion(request)
async def _nonstream_completion(
self, request: CompletionRequest
) -> CompletionResponse:
params = self._get_params(request)
r = await self.client.completions.create(**params)
return process_completion_response(r, self.formatter)
async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
params = self._get_params(request)
stream = await self.client.completions.create(**params)
async for chunk in process_completion_stream_response(stream, self.formatter):
yield chunk
async def chat_completion(
self,
model_id: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
model = await self.model_store.get_model(model_id)
request = ChatCompletionRequest(
model=model.provider_resource_id,
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format,
response_format=response_format,
stream=stream,
logprobs=logprobs,
)
if stream:
return self._stream_chat_completion(request)
else:
return await self._nonstream_chat_completion(request)
async def _nonstream_chat_completion(
self, request: CompletionRequest
) -> CompletionResponse:
params = self._get_params(request)
r = await self.client.completions.create(**params)
return process_chat_completion_response(r, self.formatter)
async def _stream_chat_completion(
self, request: CompletionRequest
) -> AsyncGenerator:
params = self._get_params(request)
stream = await self.client.completions.create(**params)
async for chunk in process_chat_completion_stream_response(
stream, self.formatter
):
yield chunk
def _get_params(
self, request: Union[ChatCompletionRequest, CompletionRequest]
) -> dict:
if request.sampling_params and request.sampling_params.top_k:
raise ValueError("`top_k` not supported by Cerebras")
prompt = ""
if type(request) == ChatCompletionRequest:
prompt = chat_completion_request_to_prompt(
request, self.get_llama_model(request.model), self.formatter
)
elif type(request) == CompletionRequest:
prompt = completion_request_to_prompt(request, self.formatter)
else:
raise ValueError(f"Unknown request type {type(request)}")
return {
"model": request.model,
"prompt": prompt,
"stream": request.stream,
**get_sampling_options(request.sampling_params),
}
async def embeddings(
self,
model_id: str,
contents: List[InterleavedTextMedia],
) -> EmbeddingsResponse:
raise NotImplementedError()

View file

@ -0,0 +1,32 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
from typing import Any, Dict, Optional
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field
DEFAULT_BASE_URL = "https://api.cerebras.ai"
@json_schema_type
class CerebrasImplConfig(BaseModel):
base_url: str = Field(
default=os.environ.get("CEREBRAS_BASE_URL", DEFAULT_BASE_URL),
description="Base URL for the Cerebras API",
)
api_key: Optional[str] = Field(
default=os.environ.get("CEREBRAS_API_KEY"),
description="Cerebras API Key",
)
@classmethod
def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
return {
"base_url": DEFAULT_BASE_URL,
"api_key": "${env.CEREBRAS_API_KEY}",
}

View file

@ -35,7 +35,9 @@ class NVIDIAConfig(BaseModel):
"""
url: str = Field(
default="https://integrate.api.nvidia.com",
default_factory=lambda: os.getenv(
"NVIDIA_BASE_URL", "https://integrate.api.nvidia.com"
),
description="A base url for accessing the NVIDIA NIM",
)
api_key: Optional[str] = Field(

View file

@ -180,7 +180,6 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
async def _nonstream_completion(self, request: CompletionRequest) -> AsyncGenerator:
params = await self._get_params(request)
r = await self.client.generate(**params)
assert isinstance(r, dict)
choice = OpenAICompatCompletionChoice(
finish_reason=r["done_reason"] if r["done"] else None,
@ -270,7 +269,6 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
r = await self.client.chat(**params)
else:
r = await self.client.generate(**params)
assert isinstance(r, dict)
if "message" in r:
choice = OpenAICompatCompletionChoice(

View file

@ -100,6 +100,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
tool_prompt_format=tool_prompt_format,
stream=stream,
logprobs=logprobs,
response_format=response_format,
)
if stream:
return self._stream_chat_completion(request, self.client)
@ -180,6 +181,16 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
self.formatter,
)
if fmt := request.response_format:
if fmt.type == ResponseFormatType.json_schema.value:
input_dict["extra_body"] = {
"guided_json": request.response_format.json_schema
}
elif fmt.type == ResponseFormatType.grammar.value:
raise NotImplementedError("Grammar response format not supported yet")
else:
raise ValueError(f"Unknown response format {fmt.type}")
return {
"model": request.model,
**input_dict,

View file

@ -1,15 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .config import OpenTelemetryConfig
async def get_adapter_impl(config: OpenTelemetryConfig, _deps):
from .opentelemetry import OpenTelemetryAdapter
impl = OpenTelemetryAdapter(config)
await impl.initialize()
return impl

View file

@ -1,27 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from pydantic import BaseModel, Field
class OpenTelemetryConfig(BaseModel):
otel_endpoint: str = Field(
default="http://localhost:4318/v1/traces",
description="The OpenTelemetry collector endpoint URL",
)
service_name: str = Field(
default="llama-stack",
description="The service name to use for telemetry",
)
@classmethod
def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
return {
"otel_endpoint": "${env.OTEL_ENDPOINT:http://localhost:4318/v1/traces}",
"service_name": "${env.OTEL_SERVICE_NAME:llama-stack}",
}

View file

@ -81,6 +81,18 @@ class TestDatasetIO:
assert len(response) == 1
assert response[0].identifier == "test_dataset"
with pytest.raises(Exception) as exc_info:
# unregister a dataset that does not exist
await datasets_impl.unregister_dataset("test_dataset2")
await datasets_impl.unregister_dataset("test_dataset")
response = await datasets_impl.list_datasets()
assert isinstance(response, list)
assert len(response) == 0
with pytest.raises(Exception) as exc_info:
await datasets_impl.unregister_dataset("test_dataset")
@pytest.mark.asyncio
async def test_get_rows_paginated(self, datasetio_stack):
datasetio_impl, datasets_impl = datasetio_stack

View file

@ -6,10 +6,14 @@
import pytest
from ..agents.fixtures import AGENTS_FIXTURES
from ..conftest import get_provider_fixture_overrides
from ..datasetio.fixtures import DATASETIO_FIXTURES
from ..inference.fixtures import INFERENCE_FIXTURES
from ..memory.fixtures import MEMORY_FIXTURES
from ..safety.fixtures import SAFETY_FIXTURES
from ..scoring.fixtures import SCORING_FIXTURES
from .fixtures import EVAL_FIXTURES
@ -20,6 +24,9 @@ DEFAULT_PROVIDER_COMBINATIONS = [
"scoring": "basic",
"datasetio": "localfs",
"inference": "fireworks",
"agents": "meta_reference",
"safety": "llama_guard",
"memory": "faiss",
},
id="meta_reference_eval_fireworks_inference",
marks=pytest.mark.meta_reference_eval_fireworks_inference,
@ -30,6 +37,9 @@ DEFAULT_PROVIDER_COMBINATIONS = [
"scoring": "basic",
"datasetio": "localfs",
"inference": "together",
"agents": "meta_reference",
"safety": "llama_guard",
"memory": "faiss",
},
id="meta_reference_eval_together_inference",
marks=pytest.mark.meta_reference_eval_together_inference,
@ -40,6 +50,9 @@ DEFAULT_PROVIDER_COMBINATIONS = [
"scoring": "basic",
"datasetio": "huggingface",
"inference": "together",
"agents": "meta_reference",
"safety": "llama_guard",
"memory": "faiss",
},
id="meta_reference_eval_together_inference_huggingface_datasetio",
marks=pytest.mark.meta_reference_eval_together_inference_huggingface_datasetio,
@ -75,6 +88,9 @@ def pytest_generate_tests(metafunc):
"scoring": SCORING_FIXTURES,
"datasetio": DATASETIO_FIXTURES,
"inference": INFERENCE_FIXTURES,
"agents": AGENTS_FIXTURES,
"safety": SAFETY_FIXTURES,
"memory": MEMORY_FIXTURES,
}
combinations = (
get_provider_fixture_overrides(metafunc.config, available_fixtures)

View file

@ -40,14 +40,30 @@ async def eval_stack(request):
providers = {}
provider_data = {}
for key in ["datasetio", "eval", "scoring", "inference"]:
for key in [
"datasetio",
"eval",
"scoring",
"inference",
"agents",
"safety",
"memory",
]:
fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}")
providers[key] = fixture.providers
if fixture.provider_data:
provider_data.update(fixture.provider_data)
test_stack = await construct_stack_for_test(
[Api.eval, Api.datasetio, Api.inference, Api.scoring],
[
Api.eval,
Api.datasetio,
Api.inference,
Api.scoring,
Api.agents,
Api.safety,
Api.memory,
],
providers,
provider_data,
)

View file

@ -17,6 +17,7 @@ from llama_stack.providers.inline.inference.meta_reference import (
)
from llama_stack.providers.remote.inference.bedrock import BedrockConfig
from llama_stack.providers.remote.inference.cerebras import CerebrasImplConfig
from llama_stack.providers.remote.inference.fireworks import FireworksImplConfig
from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig
from llama_stack.providers.remote.inference.ollama import OllamaImplConfig
@ -64,6 +65,21 @@ def inference_meta_reference(inference_model) -> ProviderFixture:
)
@pytest.fixture(scope="session")
def inference_cerebras() -> ProviderFixture:
return ProviderFixture(
providers=[
Provider(
provider_id="cerebras",
provider_type="remote::cerebras",
config=CerebrasImplConfig(
api_key=get_env_or_fail("CEREBRAS_API_KEY"),
).model_dump(),
)
],
)
@pytest.fixture(scope="session")
def inference_ollama(inference_model) -> ProviderFixture:
inference_model = (
@ -206,6 +222,7 @@ INFERENCE_FIXTURES = [
"vllm_remote",
"remote",
"bedrock",
"cerebras",
"nvidia",
"tgi",
]

View file

@ -95,6 +95,7 @@ class TestInference:
"remote::together",
"remote::fireworks",
"remote::nvidia",
"remote::cerebras",
):
pytest.skip("Other inference providers don't support completion() yet")
@ -139,6 +140,8 @@ class TestInference:
"remote::together",
"remote::fireworks",
"remote::nvidia",
"remote::vllm",
"remote::cerebras",
):
pytest.skip(
"Other inference providers don't support structured output in completions yet"
@ -198,6 +201,7 @@ class TestInference:
"remote::fireworks",
"remote::tgi",
"remote::together",
"remote::vllm",
"remote::nvidia",
):
pytest.skip("Other inference providers don't support structured output yet")
@ -211,7 +215,15 @@ class TestInference:
response = await inference_impl.chat_completion(
model_id=inference_model,
messages=[
SystemMessage(content="You are a helpful assistant."),
# we include context about Michael Jordan in the prompt so that the test is
# focused on the funtionality of the model and not on the information embedded
# in the model. Llama 3.2 3B Instruct tends to think MJ played for 14 seasons.
SystemMessage(
content=(
"You are a helpful assistant.\n\n"
"Michael Jordan was born in 1963. He played basketball for the Chicago Bulls for 15 seasons."
)
),
UserMessage(content="Please give me information about Michael Jordan."),
],
stream=False,

Binary file not shown.

View file

@ -0,0 +1,76 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import base64
import mimetypes
import os
from pathlib import Path
import pytest
from llama_stack.apis.memory.memory import MemoryBankDocument, URL
from llama_stack.providers.utils.memory.vector_store import content_from_doc
DUMMY_PDF_PATH = Path(os.path.abspath(__file__)).parent / "fixtures" / "dummy.pdf"
def read_file(file_path: str) -> bytes:
with open(file_path, "rb") as file:
return file.read()
def data_url_from_file(file_path: str) -> str:
with open(file_path, "rb") as file:
file_content = file.read()
base64_content = base64.b64encode(file_content).decode("utf-8")
mime_type, _ = mimetypes.guess_type(file_path)
data_url = f"data:{mime_type};base64,{base64_content}"
return data_url
class TestVectorStore:
@pytest.mark.asyncio
async def test_returns_content_from_pdf_data_uri(self):
data_uri = data_url_from_file(DUMMY_PDF_PATH)
doc = MemoryBankDocument(
document_id="dummy",
content=data_uri,
mime_type="application/pdf",
metadata={},
)
content = await content_from_doc(doc)
assert content == "Dummy PDF file"
@pytest.mark.asyncio
async def test_downloads_pdf_and_returns_content(self):
# Using GitHub to host the PDF file
url = "https://raw.githubusercontent.com/meta-llama/llama-stack/da035d69cfca915318eaf485770a467ca3c2a238/llama_stack/providers/tests/memory/fixtures/dummy.pdf"
doc = MemoryBankDocument(
document_id="dummy",
content=url,
mime_type="application/pdf",
metadata={},
)
content = await content_from_doc(doc)
assert content == "Dummy PDF file"
@pytest.mark.asyncio
async def test_downloads_pdf_and_returns_content_with_url_object(self):
# Using GitHub to host the PDF file
url = "https://raw.githubusercontent.com/meta-llama/llama-stack/da035d69cfca915318eaf485770a467ca3c2a238/llama_stack/providers/tests/memory/fixtures/dummy.pdf"
doc = MemoryBankDocument(
document_id="dummy",
content=URL(
uri=url,
),
mime_type="application/pdf",
metadata={},
)
content = await content_from_doc(doc)
assert content == "Dummy PDF file"

View file

@ -10,9 +10,10 @@ import pytest_asyncio
from llama_stack.apis.models import ModelInput
from llama_stack.distribution.datatypes import Api, Provider
from llama_stack.providers.inline.scoring.braintrust import BraintrustScoringConfig
from llama_stack.providers.tests.resolver import construct_stack_for_test
from ..conftest import ProviderFixture, remote_stack_fixture
from ..env import get_env_or_fail
@pytest.fixture(scope="session")
@ -40,7 +41,9 @@ def scoring_braintrust() -> ProviderFixture:
Provider(
provider_id="braintrust",
provider_type="inline::braintrust",
config={},
config=BraintrustScoringConfig(
openai_api_key=get_env_or_fail("OPENAI_API_KEY"),
).model_dump(),
)
],
)

View file

@ -7,7 +7,12 @@
import pytest
from llama_stack.apis.scoring_functions import * # noqa: F403
from llama_stack.apis.scoring_functions import (
AggregationFunctionType,
BasicScoringFnParams,
LLMAsJudgeScoringFnParams,
RegexParserScoringFnParams,
)
from llama_stack.distribution.datatypes import Api
from llama_stack.providers.tests.datasetio.test_datasetio import register_dataset
@ -18,6 +23,11 @@ from llama_stack.providers.tests.datasetio.test_datasetio import register_datase
# -v -s --tb=short --disable-warnings
@pytest.fixture
def sample_judge_prompt_template():
return "Output a number response in the following format: Score: <number>, where <number> is the number between 0 and 9."
class TestScoring:
@pytest.mark.asyncio
async def test_scoring_functions_list(self, scoring_stack):
@ -92,7 +102,9 @@ class TestScoring:
assert len(response.results[x].score_rows) == 5
@pytest.mark.asyncio
async def test_scoring_score_with_params(self, scoring_stack):
async def test_scoring_score_with_params_llm_as_judge(
self, scoring_stack, sample_judge_prompt_template
):
(
scoring_impl,
scoring_functions_impl,
@ -129,10 +141,11 @@ class TestScoring:
assert len(rows.rows) == 3
scoring_functions = {
"llm-as-judge::llm_as_judge_base": LLMAsJudgeScoringFnParams(
"llm-as-judge::base": LLMAsJudgeScoringFnParams(
judge_model="Llama3.1-405B-Instruct",
prompt_template="Output a number response in the following format: Score: <number>, where <number> is the number between 0 and 9.",
prompt_template=sample_judge_prompt_template,
judge_score_regexes=[r"Score: (\d+)"],
aggregation_functions=[AggregationFunctionType.categorical_count],
)
}
@ -154,3 +167,67 @@ class TestScoring:
for x in scoring_functions:
assert x in response.results
assert len(response.results[x].score_rows) == 5
@pytest.mark.asyncio
async def test_scoring_score_with_aggregation_functions(
self, scoring_stack, sample_judge_prompt_template
):
(
scoring_impl,
scoring_functions_impl,
datasetio_impl,
datasets_impl,
models_impl,
) = (
scoring_stack[Api.scoring],
scoring_stack[Api.scoring_functions],
scoring_stack[Api.datasetio],
scoring_stack[Api.datasets],
scoring_stack[Api.models],
)
await register_dataset(datasets_impl)
rows = await datasetio_impl.get_rows_paginated(
dataset_id="test_dataset",
rows_in_page=3,
)
assert len(rows.rows) == 3
scoring_fns_list = await scoring_functions_impl.list_scoring_functions()
scoring_functions = {}
aggr_fns = [
AggregationFunctionType.accuracy,
AggregationFunctionType.median,
AggregationFunctionType.categorical_count,
AggregationFunctionType.average,
]
for x in scoring_fns_list:
if x.provider_id == "llm-as-judge":
aggr_fns = [AggregationFunctionType.categorical_count]
scoring_functions[x.identifier] = LLMAsJudgeScoringFnParams(
judge_model="Llama3.1-405B-Instruct",
prompt_template=sample_judge_prompt_template,
judge_score_regexes=[r"Score: (\d+)"],
aggregation_functions=aggr_fns,
)
elif x.provider_id == "basic":
if "regex_parser" in x.identifier:
scoring_functions[x.identifier] = RegexParserScoringFnParams(
aggregation_functions=aggr_fns,
)
else:
scoring_functions[x.identifier] = BasicScoringFnParams(
aggregation_functions=aggr_fns,
)
else:
scoring_functions[x.identifier] = None
response = await scoring_impl.score(
input_rows=rows.rows,
scoring_functions=scoring_functions,
)
assert len(response.results) == len(scoring_functions)
for x in scoring_functions:
assert x in response.results
assert len(response.results[x].score_rows) == len(rows.rows)
assert len(response.results[x].aggregated_results) == len(aggr_fns)

View file

@ -27,7 +27,8 @@ def supported_inference_models() -> List[Model]:
m
for m in all_registered_models()
if (
m.model_family in {ModelFamily.llama3_1, ModelFamily.llama3_2}
m.model_family
in {ModelFamily.llama3_1, ModelFamily.llama3_2, ModelFamily.llama3_3}
or is_supported_safety_model(m)
)
]

View file

@ -45,6 +45,13 @@ def get_embedding_model(model: str) -> "SentenceTransformer":
return loaded_model
def parse_pdf(data: bytes) -> str:
# For PDF and DOC/DOCX files, we can't reliably convert to string
pdf_bytes = io.BytesIO(data)
pdf_reader = PdfReader(pdf_bytes)
return "\n".join([page.extract_text() for page in pdf_reader.pages])
def parse_data_url(data_url: str):
data_url_pattern = re.compile(
r"^"
@ -88,10 +95,7 @@ def content_from_data(data_url: str) -> str:
return data.decode(encoding)
elif mime_type == "application/pdf":
# For PDF and DOC/DOCX files, we can't reliably convert to string)
pdf_bytes = io.BytesIO(data)
pdf_reader = PdfReader(pdf_bytes)
return "\n".join([page.extract_text() for page in pdf_reader.pages])
return parse_pdf(data)
else:
log.error("Could not extract content from data_url properly.")
@ -105,6 +109,9 @@ async def content_from_doc(doc: MemoryBankDocument) -> str:
else:
async with httpx.AsyncClient() as client:
r = await client.get(doc.content.uri)
if doc.mime_type == "application/pdf":
return parse_pdf(r.content)
else:
return r.text
pattern = re.compile("^(https?://|file://|data:)")
@ -114,6 +121,9 @@ async def content_from_doc(doc: MemoryBankDocument) -> str:
else:
async with httpx.AsyncClient() as client:
r = await client.get(doc.content)
if doc.mime_type == "application/pdf":
return parse_pdf(r.content)
else:
return r.text
return interleaved_text_media_as_str(doc.content)

View file

@ -3,9 +3,10 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import statistics
from typing import Any, Dict, List
from llama_stack.apis.scoring import ScoringResultRow
from llama_stack.apis.scoring import AggregationFunctionType, ScoringResultRow
def aggregate_accuracy(scoring_results: List[ScoringResultRow]) -> Dict[str, Any]:
@ -26,3 +27,38 @@ def aggregate_average(scoring_results: List[ScoringResultRow]) -> Dict[str, Any]
)
/ len([_ for _ in scoring_results if _["score"] is not None]),
}
def aggregate_categorical_count(
scoring_results: List[ScoringResultRow],
) -> Dict[str, Any]:
scores = [str(r["score"]) for r in scoring_results]
unique_scores = sorted(list(set(scores)))
return {"categorical_count": {s: scores.count(s) for s in unique_scores}}
def aggregate_median(scoring_results: List[ScoringResultRow]) -> Dict[str, Any]:
scores = [r["score"] for r in scoring_results if r["score"] is not None]
median = statistics.median(scores) if scores else None
return {"median": median}
# TODO: decide whether we want to make aggregation functions as a registerable resource
AGGREGATION_FUNCTIONS = {
AggregationFunctionType.accuracy: aggregate_accuracy,
AggregationFunctionType.average: aggregate_average,
AggregationFunctionType.categorical_count: aggregate_categorical_count,
AggregationFunctionType.median: aggregate_median,
}
def aggregate_metrics(
scoring_results: List[ScoringResultRow], metrics: List[AggregationFunctionType]
) -> Dict[str, Any]:
agg_results = {}
for metric in metrics:
if metric not in AGGREGATION_FUNCTIONS:
raise ValueError(f"Aggregation function {metric} not found")
agg_fn = AGGREGATION_FUNCTIONS[metric]
agg_results[metric] = agg_fn(scoring_results)
return agg_results

View file

@ -8,11 +8,12 @@ from typing import Any, Dict, List, Optional
from llama_stack.apis.scoring import ScoringFnParams, ScoringResultRow
from llama_stack.apis.scoring_functions import ScoringFn
from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_metrics
class BaseScoringFn(ABC):
"""
Base interface class for all meta-reference scoring_fns.
Base interface class for all native scoring_fns.
Each scoring_fn needs to implement the following methods:
- score_row(self, row)
- aggregate(self, scoring_fn_results)
@ -44,11 +45,27 @@ class BaseScoringFn(ABC):
) -> ScoringResultRow:
raise NotImplementedError()
@abstractmethod
async def aggregate(
self, scoring_results: List[ScoringResultRow]
self,
scoring_results: List[ScoringResultRow],
scoring_fn_identifier: Optional[str] = None,
scoring_params: Optional[ScoringFnParams] = None,
) -> Dict[str, Any]:
raise NotImplementedError()
params = self.supported_fn_defs_registry[scoring_fn_identifier].params
if scoring_params is not None:
if params is None:
params = scoring_params
else:
params.aggregation_functions = scoring_params.aggregation_functions
aggregation_functions = []
if (
params
and hasattr(params, "aggregation_functions")
and params.aggregation_functions
):
aggregation_functions.extend(params.aggregation_functions)
return aggregate_metrics(scoring_results, aggregation_functions)
async def score(
self,

View file

@ -0,0 +1,87 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List, Optional
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.telemetry import QueryCondition, Span, SpanWithChildren
class TelemetryDatasetMixin:
"""Mixin class that provides dataset-related functionality for telemetry providers."""
datasetio_api: DatasetIO
async def save_spans_to_dataset(
self,
attribute_filters: List[QueryCondition],
attributes_to_save: List[str],
dataset_id: str,
max_depth: Optional[int] = None,
) -> None:
spans = await self.query_spans(
attribute_filters=attribute_filters,
attributes_to_return=attributes_to_save,
max_depth=max_depth,
)
rows = [
{
"trace_id": span.trace_id,
"span_id": span.span_id,
"parent_span_id": span.parent_span_id,
"name": span.name,
"start_time": span.start_time,
"end_time": span.end_time,
**{attr: span.attributes.get(attr) for attr in attributes_to_save},
}
for span in spans
]
await self.datasetio_api.append_rows(dataset_id=dataset_id, rows=rows)
async def query_spans(
self,
attribute_filters: List[QueryCondition],
attributes_to_return: List[str],
max_depth: Optional[int] = None,
) -> List[Span]:
traces = await self.query_traces(attribute_filters=attribute_filters)
spans = []
for trace in traces:
span_tree = await self.get_span_tree(
span_id=trace.root_span_id,
attributes_to_return=attributes_to_return,
max_depth=max_depth,
)
def extract_spans(span: SpanWithChildren) -> List[Span]:
result = []
if span.attributes and all(
attr in span.attributes and span.attributes[attr] is not None
for attr in attributes_to_return
):
result.append(
Span(
trace_id=trace.root_span_id,
span_id=span.span_id,
parent_span_id=span.parent_span_id,
name=span.name,
start_time=span.start_time,
end_time=span.end_time,
attributes=span.attributes,
)
)
for child in span.children:
result.extend(extract_spans(child))
return result
spans.extend(extract_spans(span_tree))
return spans

View file

@ -0,0 +1,178 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
from datetime import datetime
from typing import List, Optional, Protocol
import aiosqlite
from llama_stack.apis.telemetry import QueryCondition, SpanWithChildren, Trace
class TraceStore(Protocol):
async def query_traces(
self,
attribute_filters: Optional[List[QueryCondition]] = None,
limit: Optional[int] = 100,
offset: Optional[int] = 0,
order_by: Optional[List[str]] = None,
) -> List[Trace]: ...
async def get_span_tree(
self,
span_id: str,
attributes_to_return: Optional[List[str]] = None,
max_depth: Optional[int] = None,
) -> SpanWithChildren: ...
class SQLiteTraceStore(TraceStore):
def __init__(self, conn_string: str):
self.conn_string = conn_string
async def query_traces(
self,
attribute_filters: Optional[List[QueryCondition]] = None,
limit: Optional[int] = 100,
offset: Optional[int] = 0,
order_by: Optional[List[str]] = None,
) -> List[Trace]:
def build_where_clause() -> tuple[str, list]:
if not attribute_filters:
return "", []
ops_map = {"eq": "=", "ne": "!=", "gt": ">", "lt": "<"}
conditions = [
f"json_extract(s.attributes, '$.{condition.key}') {ops_map[condition.op.value]} ?"
for condition in attribute_filters
]
params = [condition.value for condition in attribute_filters]
where_clause = " WHERE " + " AND ".join(conditions)
return where_clause, params
def build_order_clause() -> str:
if not order_by:
return ""
order_clauses = []
for field in order_by:
desc = field.startswith("-")
clean_field = field[1:] if desc else field
order_clauses.append(f"t.{clean_field} {'DESC' if desc else 'ASC'}")
return " ORDER BY " + ", ".join(order_clauses)
# Build the main query
base_query = """
WITH matching_traces AS (
SELECT DISTINCT t.trace_id
FROM traces t
JOIN spans s ON t.trace_id = s.trace_id
{where_clause}
),
filtered_traces AS (
SELECT t.trace_id, t.root_span_id, t.start_time, t.end_time
FROM matching_traces mt
JOIN traces t ON mt.trace_id = t.trace_id
LEFT JOIN spans s ON t.trace_id = s.trace_id
{order_clause}
)
SELECT DISTINCT trace_id, root_span_id, start_time, end_time
FROM filtered_traces
LIMIT {limit} OFFSET {offset}
"""
where_clause, params = build_where_clause()
query = base_query.format(
where_clause=where_clause,
order_clause=build_order_clause(),
limit=limit,
offset=offset,
)
# Execute query and return results
async with aiosqlite.connect(self.conn_string) as conn:
conn.row_factory = aiosqlite.Row
async with conn.execute(query, params) as cursor:
rows = await cursor.fetchall()
return [
Trace(
trace_id=row["trace_id"],
root_span_id=row["root_span_id"],
start_time=datetime.fromisoformat(row["start_time"]),
end_time=datetime.fromisoformat(row["end_time"]),
)
for row in rows
]
async def get_span_tree(
self,
span_id: str,
attributes_to_return: Optional[List[str]] = None,
max_depth: Optional[int] = None,
) -> SpanWithChildren:
# Build the attributes selection
attributes_select = "s.attributes"
if attributes_to_return:
json_object = ", ".join(
f"'{key}', json_extract(s.attributes, '$.{key}')"
for key in attributes_to_return
)
attributes_select = f"json_object({json_object})"
# SQLite CTE query with filtered attributes
query = f"""
WITH RECURSIVE span_tree AS (
SELECT s.*, 1 as depth, {attributes_select} as filtered_attributes
FROM spans s
WHERE s.span_id = ?
UNION ALL
SELECT s.*, st.depth + 1, {attributes_select} as filtered_attributes
FROM spans s
JOIN span_tree st ON s.parent_span_id = st.span_id
WHERE (? IS NULL OR st.depth < ?)
)
SELECT *
FROM span_tree
ORDER BY depth, start_time
"""
async with aiosqlite.connect(self.conn_string) as conn:
conn.row_factory = aiosqlite.Row
async with conn.execute(query, (span_id, max_depth, max_depth)) as cursor:
rows = await cursor.fetchall()
if not rows:
raise ValueError(f"Span {span_id} not found")
# Build span tree
spans_by_id = {}
root_span = None
for row in rows:
span = SpanWithChildren(
span_id=row["span_id"],
trace_id=row["trace_id"],
parent_span_id=row["parent_span_id"],
name=row["name"],
start_time=datetime.fromisoformat(row["start_time"]),
end_time=datetime.fromisoformat(row["end_time"]),
attributes=json.loads(row["filtered_attributes"]),
status=row["status"].lower(),
children=[],
)
spans_by_id[span.span_id] = span
if span.span_id == span_id:
root_span = span
elif span.parent_span_id in spans_by_id:
spans_by_id[span.parent_span_id].children.append(span)
return root_span

View file

@ -0,0 +1,141 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import inspect
from datetime import datetime
from functools import wraps
from typing import Any, AsyncGenerator, Callable, Type, TypeVar
from uuid import UUID
from pydantic import BaseModel
T = TypeVar("T")
def serialize_value(value: Any) -> Any:
"""Serialize a single value into JSON-compatible format."""
if value is None:
return None
elif isinstance(value, (str, int, float, bool)):
return value
elif isinstance(value, BaseModel):
return value.model_dump()
elif isinstance(value, (list, tuple, set)):
return [serialize_value(item) for item in value]
elif isinstance(value, dict):
return {str(k): serialize_value(v) for k, v in value.items()}
elif isinstance(value, (datetime, UUID)):
return str(value)
else:
return str(value)
def trace_protocol(cls: Type[T]) -> Type[T]:
"""
A class decorator that automatically traces all methods in a protocol/base class
and its inheriting classes.
"""
def trace_method(method: Callable) -> Callable:
from llama_stack.providers.utils.telemetry import tracing
is_async = asyncio.iscoroutinefunction(method)
is_async_gen = inspect.isasyncgenfunction(method)
def create_span_context(self: Any, *args: Any, **kwargs: Any) -> tuple:
class_name = self.__class__.__name__
method_name = method.__name__
span_type = (
"async_generator" if is_async_gen else "async" if is_async else "sync"
)
sig = inspect.signature(method)
param_names = list(sig.parameters.keys())[1:] # Skip 'self'
combined_args = {}
for i, arg in enumerate(args):
param_name = (
param_names[i] if i < len(param_names) else f"position_{i+1}"
)
combined_args[param_name] = serialize_value(arg)
for k, v in kwargs.items():
combined_args[str(k)] = serialize_value(v)
span_attributes = {
"__autotraced__": True,
"__class__": class_name,
"__method__": method_name,
"__type__": span_type,
"__args__": str(combined_args),
}
return class_name, method_name, span_attributes
@wraps(method)
async def async_gen_wrapper(
self: Any, *args: Any, **kwargs: Any
) -> AsyncGenerator:
class_name, method_name, span_attributes = create_span_context(
self, *args, **kwargs
)
with tracing.span(f"{class_name}.{method_name}", span_attributes) as span:
try:
count = 0
async for item in method(self, *args, **kwargs):
yield item
count += 1
finally:
span.set_attribute("chunk_count", count)
@wraps(method)
async def async_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
class_name, method_name, span_attributes = create_span_context(
self, *args, **kwargs
)
with tracing.span(f"{class_name}.{method_name}", span_attributes) as span:
try:
result = await method(self, *args, **kwargs)
span.set_attribute("output", serialize_value(result))
return result
except Exception as e:
span.set_attribute("error", str(e))
raise
@wraps(method)
def sync_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
class_name, method_name, span_attributes = create_span_context(
self, *args, **kwargs
)
with tracing.span(f"{class_name}.{method_name}", span_attributes) as span:
try:
result = method(self, *args, **kwargs)
span.set_attribute("output", serialize_value(result))
return result
except Exception as _e:
raise
if is_async_gen:
return async_gen_wrapper
elif is_async:
return async_wrapper
else:
return sync_wrapper
original_init_subclass = getattr(cls, "__init_subclass__", None)
def __init_subclass__(cls_child, **kwargs): # noqa: N807
if original_init_subclass:
original_init_subclass(**kwargs)
for name, method in vars(cls_child).items():
if inspect.isfunction(method) and not name.startswith("_"):
setattr(cls_child, name, trace_method(method)) # noqa: B010
cls.__init_subclass__ = classmethod(__init_subclass__)
return cls

View file

@ -69,7 +69,7 @@ class TraceContext:
self.logger = logger
self.trace_id = trace_id
def push_span(self, name: str, attributes: Dict[str, Any] = None):
def push_span(self, name: str, attributes: Dict[str, Any] = None) -> Span:
current_span = self.get_current_span()
span = Span(
span_id=generate_short_uuid(),
@ -94,6 +94,7 @@ class TraceContext:
)
self.spans.append(span)
return span
def pop_span(self, status: SpanStatus = SpanStatus.OK):
span = self.spans.pop()
@ -203,12 +204,13 @@ class SpanContextManager:
def __init__(self, name: str, attributes: Dict[str, Any] = None):
self.name = name
self.attributes = attributes
self.span = None
def __enter__(self):
global CURRENT_TRACE_CONTEXT
context = CURRENT_TRACE_CONTEXT
if context:
context.push_span(self.name, self.attributes)
self.span = context.push_span(self.name, self.attributes)
return self
def __exit__(self, exc_type, exc_value, traceback):
@ -217,11 +219,24 @@ class SpanContextManager:
if context:
context.pop_span()
def set_attribute(self, key: str, value: Any):
if self.span:
if self.span.attributes is None:
self.span.attributes = {}
self.span.attributes[key] = value
async def __aenter__(self):
return self.__enter__()
global CURRENT_TRACE_CONTEXT
context = CURRENT_TRACE_CONTEXT
if context:
self.span = context.push_span(self.name, self.attributes)
return self
async def __aexit__(self, exc_type, exc_value, traceback):
self.__exit__(exc_type, exc_value, traceback)
global CURRENT_TRACE_CONTEXT
context = CURRENT_TRACE_CONTEXT
if context:
context.pop_span()
def __call__(self, func: Callable):
@wraps(func)
@ -246,3 +261,11 @@ class SpanContextManager:
def span(name: str, attributes: Dict[str, Any] = None):
return SpanContextManager(name, attributes)
def get_current_span() -> Optional[Span]:
global CURRENT_TRACE_CONTEXT
context = CURRENT_TRACE_CONTEXT
if context:
return context.get_current_span()
return None