Merge branch 'meta-llama:main' into main

This commit is contained in:
Shrinit Goyal 2024-12-12 12:42:45 +05:30 committed by GitHub
commit fced5ec6dd
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
208 changed files with 7952 additions and 1104 deletions

View file

@ -53,8 +53,6 @@ class ShieldsProtocolPrivate(Protocol):
class MemoryBanksProtocolPrivate(Protocol):
async def list_memory_banks(self) -> List[MemoryBank]: ...
async def register_memory_bank(self, memory_bank: MemoryBank) -> None: ...
async def unregister_memory_bank(self, memory_bank_id: str) -> None: ...
@ -63,6 +61,8 @@ class MemoryBanksProtocolPrivate(Protocol):
class DatasetsProtocolPrivate(Protocol):
async def register_dataset(self, dataset: Dataset) -> None: ...
async def unregister_dataset(self, dataset_id: str) -> None: ...
class ScoringFunctionsProtocolPrivate(Protocol):
async def list_scoring_functions(self) -> List[ScoringFn]: ...

View file

@ -10,9 +10,7 @@ import logging
import os
import re
import secrets
import shutil
import string
import tempfile
import uuid
from datetime import datetime
from typing import AsyncGenerator, List, Tuple
@ -57,6 +55,7 @@ class ChatAgent(ShieldRunnerMixin):
self,
agent_id: str,
agent_config: AgentConfig,
tempdir: str,
inference_api: Inference,
memory_api: Memory,
memory_banks_api: MemoryBanks,
@ -65,14 +64,13 @@ class ChatAgent(ShieldRunnerMixin):
):
self.agent_id = agent_id
self.agent_config = agent_config
self.tempdir = tempdir
self.inference_api = inference_api
self.memory_api = memory_api
self.memory_banks_api = memory_banks_api
self.safety_api = safety_api
self.storage = AgentPersistence(agent_id, persistence_store)
self.tempdir = tempfile.mkdtemp()
builtin_tools = []
for tool_defn in agent_config.tools:
if isinstance(tool_defn, WolframAlphaToolDefinition):
@ -103,9 +101,6 @@ class ChatAgent(ShieldRunnerMixin):
output_shields=agent_config.output_shields,
)
def __del__(self):
shutil.rmtree(self.tempdir)
def turn_to_messages(self, turn: Turn) -> List[Message]:
messages = []
@ -144,87 +139,91 @@ class ChatAgent(ShieldRunnerMixin):
async def create_session(self, name: str) -> str:
return await self.storage.create_session(name)
@tracing.span("create_and_execute_turn")
async def create_and_execute_turn(
self, request: AgentTurnCreateRequest
) -> AsyncGenerator:
assert request.stream is True, "Non-streaming not supported"
with tracing.span("create_and_execute_turn") as span:
span.set_attribute("session_id", request.session_id)
span.set_attribute("agent_id", self.agent_id)
span.set_attribute("request", request.model_dump_json())
assert request.stream is True, "Non-streaming not supported"
session_info = await self.storage.get_session_info(request.session_id)
if session_info is None:
raise ValueError(f"Session {request.session_id} not found")
session_info = await self.storage.get_session_info(request.session_id)
if session_info is None:
raise ValueError(f"Session {request.session_id} not found")
turns = await self.storage.get_session_turns(request.session_id)
turns = await self.storage.get_session_turns(request.session_id)
messages = []
if self.agent_config.instructions != "":
messages.append(SystemMessage(content=self.agent_config.instructions))
messages = []
if self.agent_config.instructions != "":
messages.append(SystemMessage(content=self.agent_config.instructions))
for i, turn in enumerate(turns):
messages.extend(self.turn_to_messages(turn))
for i, turn in enumerate(turns):
messages.extend(self.turn_to_messages(turn))
messages.extend(request.messages)
messages.extend(request.messages)
turn_id = str(uuid.uuid4())
start_time = datetime.now()
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseTurnStartPayload(
turn_id=turn_id,
turn_id = str(uuid.uuid4())
span.set_attribute("turn_id", turn_id)
start_time = datetime.now()
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseTurnStartPayload(
turn_id=turn_id,
)
)
)
)
steps = []
output_message = None
async for chunk in self.run(
session_id=request.session_id,
turn_id=turn_id,
input_messages=messages,
attachments=request.attachments or [],
sampling_params=self.agent_config.sampling_params,
stream=request.stream,
):
if isinstance(chunk, CompletionMessage):
log.info(
f"{chunk.role.capitalize()}: {chunk.content}",
)
output_message = chunk
continue
assert isinstance(
chunk, AgentTurnResponseStreamChunk
), f"Unexpected type {type(chunk)}"
event = chunk.event
if (
event.payload.event_type
== AgentTurnResponseEventType.step_complete.value
steps = []
output_message = None
async for chunk in self.run(
session_id=request.session_id,
turn_id=turn_id,
input_messages=messages,
attachments=request.attachments or [],
sampling_params=self.agent_config.sampling_params,
stream=request.stream,
):
steps.append(event.payload.step_details)
if isinstance(chunk, CompletionMessage):
log.info(
f"{chunk.role.capitalize()}: {chunk.content}",
)
output_message = chunk
continue
yield chunk
assert isinstance(
chunk, AgentTurnResponseStreamChunk
), f"Unexpected type {type(chunk)}"
event = chunk.event
if (
event.payload.event_type
== AgentTurnResponseEventType.step_complete.value
):
steps.append(event.payload.step_details)
assert output_message is not None
yield chunk
turn = Turn(
turn_id=turn_id,
session_id=request.session_id,
input_messages=request.messages,
output_message=output_message,
started_at=start_time,
completed_at=datetime.now(),
steps=steps,
)
await self.storage.add_turn_to_session(request.session_id, turn)
assert output_message is not None
chunk = AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseTurnCompletePayload(
turn=turn,
turn = Turn(
turn_id=turn_id,
session_id=request.session_id,
input_messages=request.messages,
output_message=output_message,
started_at=start_time,
completed_at=datetime.now(),
steps=steps,
)
await self.storage.add_turn_to_session(request.session_id, turn)
chunk = AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseTurnCompletePayload(
turn=turn,
)
)
)
)
yield chunk
yield chunk
async def run(
self,
@ -273,7 +272,6 @@ class ChatAgent(ShieldRunnerMixin):
yield final_response
@tracing.span("run_shields")
async def run_multiple_shields_wrapper(
self,
turn_id: str,
@ -281,23 +279,46 @@ class ChatAgent(ShieldRunnerMixin):
shields: List[str],
touchpoint: str,
) -> AsyncGenerator:
if len(shields) == 0:
return
with tracing.span("run_shields") as span:
span.set_attribute("input", [m.model_dump_json() for m in messages])
if len(shields) == 0:
span.set_attribute("output", "no shields")
return
step_id = str(uuid.uuid4())
try:
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepStartPayload(
step_type=StepType.shield_call.value,
step_id=step_id,
metadata=dict(touchpoint=touchpoint),
step_id = str(uuid.uuid4())
try:
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepStartPayload(
step_type=StepType.shield_call.value,
step_id=step_id,
metadata=dict(touchpoint=touchpoint),
)
)
)
)
await self.run_multiple_shields(messages, shields)
await self.run_multiple_shields(messages, shields)
except SafetyException as e:
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepCompletePayload(
step_type=StepType.shield_call.value,
step_details=ShieldCallStep(
step_id=step_id,
turn_id=turn_id,
violation=e.violation,
),
)
)
)
span.set_attribute("output", e.violation.model_dump_json())
yield CompletionMessage(
content=str(e),
stop_reason=StopReason.end_of_turn,
)
yield False
except SafetyException as e:
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepCompletePayload(
@ -305,30 +326,12 @@ class ChatAgent(ShieldRunnerMixin):
step_details=ShieldCallStep(
step_id=step_id,
turn_id=turn_id,
violation=e.violation,
violation=None,
),
)
)
)
yield CompletionMessage(
content=str(e),
stop_reason=StopReason.end_of_turn,
)
yield False
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepCompletePayload(
step_type=StepType.shield_call.value,
step_details=ShieldCallStep(
step_id=step_id,
turn_id=turn_id,
violation=None,
),
)
)
)
span.set_attribute("output", "no violations")
async def _run(
self,
@ -356,10 +359,15 @@ class ChatAgent(ShieldRunnerMixin):
# TODO: find older context from the session and either replace it
# or append with a sliding window. this is really a very simplistic implementation
with tracing.span("retrieve_rag_context"):
with tracing.span("retrieve_rag_context") as span:
rag_context, bank_ids = await self._retrieve_context(
session_id, input_messages, attachments
)
span.set_attribute(
"input", [m.model_dump_json() for m in input_messages]
)
span.set_attribute("output", rag_context)
span.set_attribute("bank_ids", bank_ids)
step_id = str(uuid.uuid4())
yield AgentTurnResponseStreamChunk(
@ -396,11 +404,6 @@ class ChatAgent(ShieldRunnerMixin):
n_iter = 0
while True:
msg = input_messages[-1]
if len(str(msg)) > 1000:
msg_str = f"{str(msg)[:500]}...<more>...{str(msg)[-500:]}"
else:
msg_str = str(msg)
log.info(f"{msg_str}")
step_id = str(uuid.uuid4())
yield AgentTurnResponseStreamChunk(
@ -416,7 +419,7 @@ class ChatAgent(ShieldRunnerMixin):
content = ""
stop_reason = None
with tracing.span("inference"):
with tracing.span("inference") as span:
async for chunk in await self.inference_api.chat_completion(
self.agent_config.model,
input_messages,
@ -436,14 +439,13 @@ class ChatAgent(ShieldRunnerMixin):
if isinstance(delta, ToolCallDelta):
if delta.parse_status == ToolCallParseStatus.success:
tool_calls.append(delta.content)
if stream:
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepProgressPayload(
step_type=StepType.inference.value,
step_id=step_id,
model_response_text_delta="",
text_delta="",
tool_call_delta=delta,
)
)
@ -457,7 +459,7 @@ class ChatAgent(ShieldRunnerMixin):
payload=AgentTurnResponseStepProgressPayload(
step_type=StepType.inference.value,
step_id=step_id,
model_response_text_delta=event.delta,
text_delta=event.delta,
)
)
)
@ -466,6 +468,13 @@ class ChatAgent(ShieldRunnerMixin):
if event.stop_reason is not None:
stop_reason = event.stop_reason
span.set_attribute("stop_reason", stop_reason)
span.set_attribute(
"input", [m.model_dump_json() for m in input_messages]
)
span.set_attribute(
"output", f"content: {content} tool_calls: {tool_calls}"
)
stop_reason = stop_reason or StopReason.out_of_tokens
@ -549,7 +558,13 @@ class ChatAgent(ShieldRunnerMixin):
)
)
with tracing.span("tool_execution"):
with tracing.span(
"tool_execution",
{
"tool_name": tool_call.tool_name,
"input": message.model_dump_json(),
},
) as span:
result_messages = await execute_tool_call_maybe(
self.tools_dict,
[message],
@ -558,6 +573,7 @@ class ChatAgent(ShieldRunnerMixin):
len(result_messages) == 1
), "Currently not supporting multiple messages"
result_message = result_messages[0]
span.set_attribute("output", result_message.model_dump_json())
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(

View file

@ -6,9 +6,13 @@
import json
import logging
import shutil
import tempfile
import uuid
from typing import AsyncGenerator
from termcolor import colored
from llama_stack.apis.inference import Inference
from llama_stack.apis.memory import Memory
from llama_stack.apis.memory_banks import MemoryBanks
@ -40,10 +44,20 @@ class MetaReferenceAgentsImpl(Agents):
self.memory_banks_api = memory_banks_api
self.in_memory_store = InmemoryKVStoreImpl()
self.tempdir = tempfile.mkdtemp()
async def initialize(self) -> None:
self.persistence_store = await kvstore_impl(self.config.persistence_store)
# check if "bwrap" is available
if not shutil.which("bwrap"):
print(
colored(
"Warning: `bwrap` is not available. Code interpreter tool will not work correctly.",
"yellow",
)
)
async def create_agent(
self,
agent_config: AgentConfig,
@ -82,6 +96,7 @@ class MetaReferenceAgentsImpl(Agents):
return ChatAgent(
agent_id=agent_id,
agent_config=agent_config,
tempdir=self.tempdir,
inference_api=self.inference_api,
safety_api=self.safety_api,
memory_api=self.memory_api,

View file

@ -3,14 +3,17 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Optional
from typing import Any, Dict, List, Optional
import pandas
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.datasetio import * # noqa: F403
import base64
import os
from abc import ABC, abstractmethod
from dataclasses import dataclass
from urllib.parse import urlparse
from llama_stack.providers.datatypes import DatasetsProtocolPrivate
from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url
@ -97,6 +100,9 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
dataset_impl=dataset_impl,
)
async def unregister_dataset(self, dataset_id: str) -> None:
del self.dataset_infos[dataset_id]
async def get_rows_paginated(
self,
dataset_id: str,
@ -128,3 +134,41 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
total_count=len(rows),
next_page_token=str(end),
)
async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
dataset_info = self.dataset_infos.get(dataset_id)
if dataset_info is None:
raise ValueError(f"Dataset with id {dataset_id} not found")
dataset_impl = dataset_info.dataset_impl
dataset_impl.load()
new_rows_df = pandas.DataFrame(rows)
new_rows_df = dataset_impl._validate_dataset_schema(new_rows_df)
dataset_impl.df = pandas.concat(
[dataset_impl.df, new_rows_df], ignore_index=True
)
url = str(dataset_info.dataset_def.url)
parsed_url = urlparse(url)
if parsed_url.scheme == "file" or not parsed_url.scheme:
file_path = parsed_url.path
os.makedirs(os.path.dirname(file_path), exist_ok=True)
dataset_impl.df.to_csv(file_path, index=False)
elif parsed_url.scheme == "data":
# For data URLs, we need to update the base64-encoded content
if not parsed_url.path.startswith("text/csv;base64,"):
raise ValueError("Data URL must be a base64-encoded CSV")
csv_buffer = dataset_impl.df.to_csv(index=False)
base64_content = base64.b64encode(csv_buffer.encode("utf-8")).decode(
"utf-8"
)
dataset_info.dataset_def.url = URL(
uri=f"data:text/csv;base64,{base64_content}"
)
else:
raise ValueError(
f"Unsupported URL scheme: {parsed_url.scheme}. Only file:// and data: URLs are supported for writing."
)

View file

@ -3,12 +3,13 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pydantic import BaseModel
from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR
from llama_stack.providers.utils.kvstore.config import (
KVStoreConfig,
SqliteKVStoreConfig,
)
from pydantic import BaseModel
class MetaReferenceEvalConfig(BaseModel):

View file

@ -4,7 +4,9 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from typing import Any, Dict, List, Optional
from llama_models.llama3.api.datatypes import * # noqa: F403
from tqdm import tqdm
from .....apis.common.job_types import Job
from .....apis.eval.eval import Eval, EvalTaskConfig, EvaluateResponse, JobStatus
@ -17,7 +19,6 @@ from llama_stack.apis.inference import Inference
from llama_stack.apis.scoring import Scoring
from llama_stack.providers.datatypes import EvalTasksProtocolPrivate
from llama_stack.providers.utils.kvstore import kvstore_impl
from tqdm import tqdm
from .config import MetaReferenceEvalConfig

View file

@ -4,12 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .config import ConsoleConfig
from .config import ChromaInlineImplConfig
async def get_provider_impl(config: ConsoleConfig, _deps):
from .console import ConsoleTelemetryImpl
async def get_provider_impl(config: ChromaInlineImplConfig, _deps):
from llama_stack.providers.remote.memory.chroma.chroma import ChromaMemoryAdapter
impl = ConsoleTelemetryImpl(config)
impl = ChromaMemoryAdapter(config)
await impl.initialize()
return impl

View file

@ -4,18 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from llama_models.schema_utils import json_schema_type
from typing import Any, Dict
from pydantic import BaseModel
class LogFormat(Enum):
TEXT = "text"
JSON = "json"
class ChromaInlineImplConfig(BaseModel):
db_path: str
@json_schema_type
class ConsoleConfig(BaseModel):
log_format: LogFormat = LogFormat.TEXT
@classmethod
def sample_config(cls) -> Dict[str, Any]:
return {"db_path": "{env.CHROMADB_PATH}"}

View file

@ -27,7 +27,6 @@ from llama_stack.providers.utils.memory.vector_store import (
BankWithIndex,
EmbeddingIndex,
)
from llama_stack.providers.utils.telemetry import tracing
from .config import FaissImplConfig
@ -95,7 +94,6 @@ class FaissIndex(EmbeddingIndex):
await self.kvstore.delete(f"faiss_index:v1::{self.bank_id}")
@tracing.span(name="add_chunks")
async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
indexlen = len(self.id_by_index)
for i, chunk in enumerate(chunks):

View file

@ -5,7 +5,7 @@
# the root directory of this source tree.
import json
from typing import Optional
from typing import List, Optional
from .config import LogFormat
@ -49,8 +49,27 @@ class ConsoleTelemetryImpl(Telemetry):
if formatted:
print(formatted)
async def get_trace(self, trace_id: str) -> Trace:
raise NotImplementedError()
async def query_traces(
self,
attribute_conditions: Optional[List[QueryCondition]] = None,
attribute_keys_to_return: Optional[List[str]] = None,
limit: Optional[int] = 100,
offset: Optional[int] = 0,
order_by: Optional[List[str]] = None,
) -> List[Trace]:
raise NotImplementedError("Console telemetry does not support trace querying")
async def get_spans(
self,
span_id: str,
attribute_conditions: Optional[List[QueryCondition]] = None,
attribute_keys_to_return: Optional[List[str]] = None,
max_depth: Optional[int] = None,
limit: Optional[int] = 100,
offset: Optional[int] = 0,
order_by: Optional[List[str]] = None,
) -> SpanWithChildren:
raise NotImplementedError("Console telemetry does not support span querying")
COLORS = {

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -113,7 +113,9 @@ class BasicScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
score_results = await scoring_fn.score(
input_rows, scoring_fn_id, scoring_fn_params
)
agg_results = await scoring_fn.aggregate(score_results)
agg_results = await scoring_fn.aggregate(
score_results, scoring_fn_id, scoring_fn_params
)
res[scoring_fn_id] = ScoringResult(
score_rows=score_results,
aggregated_results=agg_results,

View file

@ -4,12 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn
from llama_stack.apis.scoring_functions import * # noqa: F401, F403
from llama_stack.apis.scoring import * # noqa: F401, F403
from llama_stack.apis.common.type_system import * # noqa: F403
from typing import Any, Dict, Optional
from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_accuracy
from llama_stack.apis.scoring import ScoringResultRow
from llama_stack.apis.scoring_functions import ScoringFnParams
from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn
from .fn_defs.equality import equality
@ -42,8 +42,3 @@ class EqualityScoringFn(BaseScoringFn):
return {
"score": score,
}
async def aggregate(
self, scoring_results: List[ScoringResultRow]
) -> Dict[str, Any]:
return aggregate_accuracy(scoring_results)

View file

@ -5,14 +5,20 @@
# the root directory of this source tree.
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import ScoringFn
from llama_stack.apis.scoring_functions import (
AggregationFunctionType,
BasicScoringFnParams,
ScoringFn,
)
equality = ScoringFn(
identifier="basic::equality",
description="Returns 1.0 if the input is equal to the target, 0.0 otherwise.",
params=None,
provider_id="basic",
provider_resource_id="equality",
return_type=NumberType(),
params=BasicScoringFnParams(
aggregation_functions=[AggregationFunctionType.accuracy]
),
)

View file

@ -4,9 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.scoring_functions import * # noqa: F401, F403
from llama_stack.apis.scoring import * # noqa: F401, F403
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import (
AggregationFunctionType,
RegexParserScoringFnParams,
ScoringFn,
)
MULTILINGUAL_ANSWER_REGEXES = [
r"Answer\s*:",
@ -67,5 +70,6 @@ regex_parser_multiple_choice_answer = ScoringFn(
MULTILINGUAL_ANSWER_PATTERN_TEMPLATE.format(x)
for x in MULTILINGUAL_ANSWER_REGEXES
],
aggregation_functions=[AggregationFunctionType.accuracy],
),
)

View file

@ -5,7 +5,11 @@
# the root directory of this source tree.
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import ScoringFn
from llama_stack.apis.scoring_functions import (
AggregationFunctionType,
BasicScoringFnParams,
ScoringFn,
)
subset_of = ScoringFn(
@ -14,4 +18,7 @@ subset_of = ScoringFn(
return_type=NumberType(),
provider_id="basic",
provider_resource_id="subset-of",
params=BasicScoringFnParams(
aggregation_functions=[AggregationFunctionType.accuracy]
),
)

View file

@ -5,11 +5,11 @@
# the root directory of this source tree.
import re
from typing import Any, Dict, Optional
from llama_stack.apis.scoring import ScoringResultRow
from llama_stack.apis.scoring_functions import ScoringFnParams, ScoringFnParamsType
from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn
from llama_stack.apis.scoring_functions import * # noqa: F401, F403
from llama_stack.apis.scoring import * # noqa: F401, F403
from llama_stack.apis.common.type_system import * # noqa: F403
from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_accuracy
from .fn_defs.regex_parser_multiple_choice_answer import (
regex_parser_multiple_choice_answer,
@ -60,8 +60,3 @@ class RegexParserScoringFn(BaseScoringFn):
return {
"score": score,
}
async def aggregate(
self, scoring_results: List[ScoringResultRow]
) -> Dict[str, Any]:
return aggregate_accuracy(scoring_results)

View file

@ -4,11 +4,11 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, Optional
from llama_stack.apis.scoring import ScoringResultRow
from llama_stack.apis.scoring_functions import ScoringFnParams
from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn
from llama_stack.apis.scoring_functions import * # noqa: F401, F403
from llama_stack.apis.scoring import * # noqa: F401, F403
from llama_stack.apis.common.type_system import * # noqa: F403
from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_accuracy
from .fn_defs.subset_of import subset_of
@ -36,8 +36,3 @@ class SubsetOfScoringFn(BaseScoringFn):
return {
"score": score,
}
async def aggregate(
self, scoring_results: List[ScoringResultRow]
) -> Dict[str, Any]:
return aggregate_accuracy(scoring_results)

View file

@ -5,9 +5,10 @@
# the root directory of this source tree.
from typing import Dict
from llama_stack.distribution.datatypes import Api, ProviderSpec
from pydantic import BaseModel
from llama_stack.distribution.datatypes import Api, ProviderSpec
from .config import BraintrustScoringConfig

View file

@ -16,6 +16,7 @@ import os
from autoevals.llm import Factuality
from autoevals.ragas import AnswerCorrectness
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
@ -85,7 +86,7 @@ class BraintrustScoringImpl(
async def set_api_key(self) -> None:
# api key is in the request headers
if self.config.openai_api_key is None:
if not self.config.openai_api_key:
provider_data = self.get_request_provider_data()
if provider_data is None or not provider_data.openai_api_key:
raise ValueError(
@ -146,7 +147,7 @@ class BraintrustScoringImpl(
await self.score_row(input_row, scoring_fn_id)
for input_row in input_rows
]
aggregation_functions = [AggregationFunctionType.average]
agg_results = aggregate_average(score_results)
res[scoring_fn_id] = ScoringResult(
score_rows=score_results,

View file

@ -11,3 +11,9 @@ class BraintrustScoringConfig(BaseModel):
default=None,
description="The OpenAI API Key",
)
@classmethod
def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
return {
"openai_api_key": "${env.OPENAI_API_KEY:}",
}

View file

@ -120,7 +120,9 @@ class LlmAsJudgeScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
score_results = await scoring_fn.score(
input_rows, scoring_fn_id, scoring_fn_params
)
agg_results = await scoring_fn.aggregate(score_results)
agg_results = await scoring_fn.aggregate(
score_results, scoring_fn_id, scoring_fn_params
)
res[scoring_fn_id] = ScoringResult(
score_rows=score_results,
aggregated_results=agg_results,

View file

@ -5,7 +5,7 @@
# the root directory of this source tree.
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import ScoringFn
from llama_stack.apis.scoring_functions import LLMAsJudgeScoringFnParams, ScoringFn
llm_as_judge_base = ScoringFn(
@ -14,4 +14,8 @@ llm_as_judge_base = ScoringFn(
return_type=NumberType(),
provider_id="llm-as-judge",
provider_resource_id="llm-as-judge-base",
params=LLMAsJudgeScoringFnParams(
judge_model="meta-llama/Llama-3.1-405B-Instruct",
prompt_template="Enter custom LLM as Judge Prompt Template",
),
)

View file

@ -3,13 +3,16 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import re
from typing import Any, Dict, Optional
from llama_stack.apis.inference.inference import Inference
from llama_stack.apis.scoring import ScoringResultRow
from llama_stack.apis.scoring_functions import ScoringFnParams
from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn
from llama_stack.apis.scoring_functions import * # noqa: F401, F403
from llama_stack.apis.scoring import * # noqa: F401, F403
from llama_stack.apis.common.type_system import * # noqa: F403
import re
from .fn_defs.llm_as_judge_405b_simpleqa import llm_as_judge_405b_simpleqa
@ -85,9 +88,3 @@ class LlmAsJudgeScoringFn(BaseScoringFn):
"score": judge_rating,
"judge_feedback": content,
}
async def aggregate(
self, scoring_results: List[ScoringResultRow]
) -> Dict[str, Any]:
# TODO: this needs to be config based aggregation, and only useful w/ Jobs API
return {}

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,19 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from .config import TelemetryConfig, TelemetrySink
__all__ = ["TelemetryConfig", "TelemetrySink"]
async def get_provider_impl(config: TelemetryConfig, deps: Dict[str, Any]):
from .telemetry import TelemetryAdapter
impl = TelemetryAdapter(config, deps)
await impl.initialize()
return impl

View file

@ -0,0 +1,58 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from typing import Any, Dict, List
from pydantic import BaseModel, Field, field_validator
from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR
class TelemetrySink(str, Enum):
OTEL = "otel"
SQLITE = "sqlite"
CONSOLE = "console"
class TelemetryConfig(BaseModel):
otel_endpoint: str = Field(
default="http://localhost:4318/v1/traces",
description="The OpenTelemetry collector endpoint URL",
)
service_name: str = Field(
default="llama-stack",
description="The service name to use for telemetry",
)
sinks: List[TelemetrySink] = Field(
default=[TelemetrySink.CONSOLE, TelemetrySink.SQLITE],
description="List of telemetry sinks to enable (possible values: otel, sqlite, console)",
)
sqlite_db_path: str = Field(
default=(RUNTIME_BASE_DIR / "trace_store.db").as_posix(),
description="The path to the SQLite database to use for storing traces",
)
@field_validator("sinks", mode="before")
@classmethod
def validate_sinks(cls, v):
if isinstance(v, str):
return [TelemetrySink(sink.strip()) for sink in v.split(",")]
return v
@classmethod
def sample_run_config(
cls, __distro_dir__: str = "runtime", db_name: str = "trace_store.db"
) -> Dict[str, Any]:
return {
"service_name": "${env.OTEL_SERVICE_NAME:llama-stack}",
"sinks": "${env.TELEMETRY_SINKS:console,sqlite}",
"sqlite_db_path": "${env.SQLITE_DB_PATH:~/.llama/"
+ __distro_dir__
+ "/"
+ db_name
+ "}",
}

View file

@ -0,0 +1,117 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
from datetime import datetime
from opentelemetry.sdk.trace import ReadableSpan
from opentelemetry.sdk.trace.export import SpanProcessor
from opentelemetry.trace.status import StatusCode
# Colors for console output
COLORS = {
"reset": "\033[0m",
"bold": "\033[1m",
"dim": "\033[2m",
"red": "\033[31m",
"green": "\033[32m",
"yellow": "\033[33m",
"blue": "\033[34m",
"magenta": "\033[35m",
"cyan": "\033[36m",
"white": "\033[37m",
}
class ConsoleSpanProcessor(SpanProcessor):
def __init__(self, print_attributes: bool = False):
self.print_attributes = print_attributes
def on_start(self, span: ReadableSpan, parent_context=None) -> None:
if span.attributes and span.attributes.get("__autotraced__"):
return
timestamp = datetime.utcfromtimestamp(span.start_time / 1e9).strftime(
"%H:%M:%S.%f"
)[:-3]
print(
f"{COLORS['dim']}{timestamp}{COLORS['reset']} "
f"{COLORS['magenta']}[START]{COLORS['reset']} "
f"{COLORS['dim']}{span.name}{COLORS['reset']}"
)
def on_end(self, span: ReadableSpan) -> None:
if span.attributes and span.attributes.get("__autotraced__"):
return
timestamp = datetime.utcfromtimestamp(span.end_time / 1e9).strftime(
"%H:%M:%S.%f"
)[:-3]
span_context = (
f"{COLORS['dim']}{timestamp}{COLORS['reset']} "
f"{COLORS['magenta']}[END]{COLORS['reset']} "
f"{COLORS['dim']}{span.name}{COLORS['reset']}"
)
if span.status.status_code == StatusCode.ERROR:
span_context += f"{COLORS['reset']} {COLORS['red']}[ERROR]{COLORS['reset']}"
elif span.status.status_code != StatusCode.UNSET:
span_context += f"{COLORS['reset']} [{span.status.status_code}]"
duration_ms = (span.end_time - span.start_time) / 1e6
span_context += f"{COLORS['reset']} ({duration_ms:.2f}ms)"
print(span_context)
if self.print_attributes and span.attributes:
for key, value in span.attributes.items():
if key.startswith("__"):
continue
str_value = str(value)
if len(str_value) > 1000:
str_value = str_value[:997] + "..."
print(f" {COLORS['dim']}{key}: {str_value}{COLORS['reset']}")
for event in span.events:
event_time = datetime.utcfromtimestamp(event.timestamp / 1e9).strftime(
"%H:%M:%S.%f"
)[:-3]
severity = event.attributes.get("severity", "info")
message = event.attributes.get("message", event.name)
if isinstance(message, (dict, list)):
message = json.dumps(message, indent=2)
severity_colors = {
"error": f"{COLORS['bold']}{COLORS['red']}",
"warn": f"{COLORS['bold']}{COLORS['yellow']}",
"info": COLORS["white"],
"debug": COLORS["dim"],
}
msg_color = severity_colors.get(severity, COLORS["white"])
print(
f" {event_time} "
f"{msg_color}[{severity.upper()}] "
f"{message}{COLORS['reset']}"
)
if event.attributes:
for key, value in event.attributes.items():
if key.startswith("__") or key in ["message", "severity"]:
continue
print(f" {COLORS['dim']}{key}: {value}{COLORS['reset']}")
def shutdown(self) -> None:
"""Shutdown the processor."""
pass
def force_flush(self, timeout_millis: float = None) -> bool:
"""Force flush any pending spans."""
return True

View file

@ -0,0 +1,177 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import os
import sqlite3
from datetime import datetime
from opentelemetry.sdk.trace import SpanProcessor
from opentelemetry.trace import Span
class SQLiteSpanProcessor(SpanProcessor):
def __init__(self, conn_string):
"""Initialize the SQLite span processor with a connection string."""
self.conn_string = conn_string
self.conn = None
self.setup_database()
def _get_connection(self) -> sqlite3.Connection:
"""Get the database connection."""
if self.conn is None:
self.conn = sqlite3.connect(self.conn_string, check_same_thread=False)
return self.conn
def setup_database(self):
"""Create the necessary tables if they don't exist."""
# Create directory if it doesn't exist
os.makedirs(os.path.dirname(self.conn_string), exist_ok=True)
conn = self._get_connection()
cursor = conn.cursor()
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS traces (
trace_id TEXT PRIMARY KEY,
service_name TEXT,
root_span_id TEXT,
start_time TIMESTAMP,
end_time TIMESTAMP,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
"""
)
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS spans (
span_id TEXT PRIMARY KEY,
trace_id TEXT REFERENCES traces(trace_id),
parent_span_id TEXT,
name TEXT,
start_time TIMESTAMP,
end_time TIMESTAMP,
attributes TEXT,
status TEXT,
kind TEXT
)
"""
)
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS span_events (
id INTEGER PRIMARY KEY AUTOINCREMENT,
span_id TEXT REFERENCES spans(span_id),
name TEXT,
timestamp TIMESTAMP,
attributes TEXT
)
"""
)
cursor.execute(
"""
CREATE INDEX IF NOT EXISTS idx_traces_created_at
ON traces(created_at)
"""
)
conn.commit()
cursor.close()
def on_start(self, span: Span, parent_context=None):
"""Called when a span starts."""
pass
def on_end(self, span: Span):
"""Called when a span ends. Export the span data to SQLite."""
try:
conn = self._get_connection()
cursor = conn.cursor()
trace_id = format(span.get_span_context().trace_id, "032x")
span_id = format(span.get_span_context().span_id, "016x")
service_name = span.resource.attributes.get("service.name", "unknown")
parent_span_id = None
parent_context = span.parent
if parent_context:
parent_span_id = format(parent_context.span_id, "016x")
# Insert into traces
cursor.execute(
"""
INSERT INTO traces (
trace_id, service_name, root_span_id, start_time, end_time
) VALUES (?, ?, ?, ?, ?)
ON CONFLICT(trace_id) DO UPDATE SET
root_span_id = COALESCE(root_span_id, excluded.root_span_id),
start_time = MIN(excluded.start_time, start_time),
end_time = MAX(excluded.end_time, end_time)
""",
(
trace_id,
service_name,
(span_id if not parent_span_id else None),
datetime.fromtimestamp(span.start_time / 1e9).isoformat(),
datetime.fromtimestamp(span.end_time / 1e9).isoformat(),
),
)
# Insert into spans
cursor.execute(
"""
INSERT INTO spans (
span_id, trace_id, parent_span_id, name,
start_time, end_time, attributes, status,
kind
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
(
span_id,
trace_id,
parent_span_id,
span.name,
datetime.fromtimestamp(span.start_time / 1e9).isoformat(),
datetime.fromtimestamp(span.end_time / 1e9).isoformat(),
json.dumps(dict(span.attributes)),
span.status.status_code.name,
span.kind.name,
),
)
for event in span.events:
cursor.execute(
"""
INSERT INTO span_events (
span_id, name, timestamp, attributes
) VALUES (?, ?, ?, ?)
""",
(
span_id,
event.name,
datetime.fromtimestamp(event.timestamp / 1e9).isoformat(),
json.dumps(dict(event.attributes)),
),
)
conn.commit()
cursor.close()
except Exception as e:
print(f"Error exporting span to SQLite: {e}")
def shutdown(self):
"""Cleanup any resources."""
if self.conn:
self.conn.close()
self.conn = None
def force_flush(self, timeout_millis=30000):
"""Force export of spans."""
pass

View file

@ -5,6 +5,7 @@
# the root directory of this source tree.
import threading
from typing import Any, Dict, List, Optional
from opentelemetry import metrics, trace
from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter
@ -16,10 +17,21 @@ from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor
from opentelemetry.semconv.resource import ResourceAttributes
from llama_stack.providers.inline.telemetry.meta_reference.console_span_processor import (
ConsoleSpanProcessor,
)
from llama_stack.providers.inline.telemetry.meta_reference.sqlite_span_processor import (
SQLiteSpanProcessor,
)
from llama_stack.providers.utils.telemetry.dataset_mixin import TelemetryDatasetMixin
from llama_stack.providers.utils.telemetry.sqlite_trace_store import SQLiteTraceStore
from llama_stack.apis.telemetry import * # noqa: F403
from .config import OpenTelemetryConfig
from llama_stack.distribution.datatypes import Api
from .config import TelemetryConfig, TelemetrySink
_GLOBAL_STORAGE = {
"active_spans": {},
@ -45,9 +57,10 @@ def is_tracing_enabled(tracer):
return span.is_recording()
class OpenTelemetryAdapter(Telemetry):
def __init__(self, config: OpenTelemetryConfig):
class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
def __init__(self, config: TelemetryConfig, deps: Dict[str, Any]) -> None:
self.config = config
self.datasetio_api = deps[Api.datasetio]
resource = Resource.create(
{
@ -57,22 +70,29 @@ class OpenTelemetryAdapter(Telemetry):
provider = TracerProvider(resource=resource)
trace.set_tracer_provider(provider)
otlp_exporter = OTLPSpanExporter(
endpoint=self.config.otel_endpoint,
)
span_processor = BatchSpanProcessor(otlp_exporter)
trace.get_tracer_provider().add_span_processor(span_processor)
# Set up metrics
metric_reader = PeriodicExportingMetricReader(
OTLPMetricExporter(
if TelemetrySink.OTEL in self.config.sinks:
otlp_exporter = OTLPSpanExporter(
endpoint=self.config.otel_endpoint,
)
)
metric_provider = MeterProvider(
resource=resource, metric_readers=[metric_reader]
)
metrics.set_meter_provider(metric_provider)
self.meter = metrics.get_meter(__name__)
span_processor = BatchSpanProcessor(otlp_exporter)
trace.get_tracer_provider().add_span_processor(span_processor)
metric_reader = PeriodicExportingMetricReader(
OTLPMetricExporter(
endpoint=self.config.otel_endpoint,
)
)
metric_provider = MeterProvider(
resource=resource, metric_readers=[metric_reader]
)
metrics.set_meter_provider(metric_provider)
self.meter = metrics.get_meter(__name__)
if TelemetrySink.SQLITE in self.config.sinks:
trace.get_tracer_provider().add_span_processor(
SQLiteSpanProcessor(self.config.sqlite_db_path)
)
self.trace_store = SQLiteTraceStore(self.config.sqlite_db_path)
if TelemetrySink.CONSOLE in self.config.sinks:
trace.get_tracer_provider().add_span_processor(ConsoleSpanProcessor())
self._lock = _global_lock
async def initialize(self) -> None:
@ -83,15 +103,17 @@ class OpenTelemetryAdapter(Telemetry):
trace.get_tracer_provider().shutdown()
metrics.get_meter_provider().shutdown()
async def log_event(self, event: Event) -> None:
async def log_event(self, event: Event, ttl_seconds: int = 604800) -> None:
if isinstance(event, UnstructuredLogEvent):
self._log_unstructured(event)
self._log_unstructured(event, ttl_seconds)
elif isinstance(event, MetricEvent):
self._log_metric(event)
elif isinstance(event, StructuredLogEvent):
self._log_structured(event)
self._log_structured(event, ttl_seconds)
else:
raise ValueError(f"Unknown event type: {event}")
def _log_unstructured(self, event: UnstructuredLogEvent) -> None:
def _log_unstructured(self, event: UnstructuredLogEvent, ttl_seconds: int) -> None:
with self._lock:
# Use global storage instead of instance storage
span_id = string_to_span_id(event.span_id)
@ -104,6 +126,7 @@ class OpenTelemetryAdapter(Telemetry):
attributes={
"message": event.message,
"severity": event.severity.value,
"__ttl__": ttl_seconds,
**event.attributes,
},
timestamp=timestamp_ns,
@ -154,11 +177,14 @@ class OpenTelemetryAdapter(Telemetry):
)
return _GLOBAL_STORAGE["up_down_counters"][name]
def _log_structured(self, event: StructuredLogEvent) -> None:
def _log_structured(self, event: StructuredLogEvent, ttl_seconds: int) -> None:
with self._lock:
span_id = string_to_span_id(event.span_id)
trace_id = string_to_trace_id(event.trace_id)
tracer = trace.get_tracer(__name__)
if event.attributes is None:
event.attributes = {}
event.attributes["__ttl__"] = ttl_seconds
if isinstance(event.payload, SpanStartPayload):
# Check if span already exists to prevent duplicates
@ -170,7 +196,6 @@ class OpenTelemetryAdapter(Telemetry):
parent_span_id = string_to_span_id(event.payload.parent_span_id)
parent_span = _GLOBAL_STORAGE["active_spans"].get(parent_span_id)
# Create a new trace context with the trace_id
context = trace.Context(trace_id=trace_id)
if parent_span:
context = trace.set_span_in_context(parent_span, context)
@ -179,14 +204,9 @@ class OpenTelemetryAdapter(Telemetry):
name=event.payload.name,
context=context,
attributes=event.attributes or {},
start_time=int(event.timestamp.timestamp() * 1e9),
)
_GLOBAL_STORAGE["active_spans"][span_id] = span
# Set as current span using context manager
with trace.use_span(span, end_on_exit=False):
pass # Let the span continue beyond this block
elif isinstance(event.payload, SpanEndPayload):
span = _GLOBAL_STORAGE["active_spans"].get(span_id)
if span:
@ -199,10 +219,33 @@ class OpenTelemetryAdapter(Telemetry):
else trace.Status(status_code=trace.StatusCode.ERROR)
)
span.set_status(status)
span.end(end_time=int(event.timestamp.timestamp() * 1e9))
# Remove from active spans
span.end()
_GLOBAL_STORAGE["active_spans"].pop(span_id, None)
else:
raise ValueError(f"Unknown structured log event: {event}")
async def get_trace(self, trace_id: str) -> Trace:
raise NotImplementedError("Trace retrieval not implemented yet")
async def query_traces(
self,
attribute_filters: Optional[List[QueryCondition]] = None,
limit: Optional[int] = 100,
offset: Optional[int] = 0,
order_by: Optional[List[str]] = None,
) -> List[Trace]:
return await self.trace_store.query_traces(
attribute_filters=attribute_filters,
limit=limit,
offset=offset,
order_by=order_by,
)
async def get_span_tree(
self,
span_id: str,
attributes_to_return: Optional[List[str]] = None,
max_depth: Optional[int] = None,
) -> SpanWithChildren:
return await self.trace_store.get_span_tree(
span_id=span_id,
attributes_to_return=attributes_to_return,
max_depth=max_depth,
)

View file

@ -61,6 +61,17 @@ def available_providers() -> List[ProviderSpec]:
config_class="llama_stack.providers.remote.inference.sample.SampleConfig",
),
),
remote_provider_spec(
api=Api.inference,
adapter=AdapterSpec(
adapter_type="cerebras",
pip_packages=[
"cerebras_cloud_sdk",
],
module="llama_stack.providers.remote.inference.cerebras",
config_class="llama_stack.providers.remote.inference.cerebras.CerebrasImplConfig",
),
),
remote_provider_spec(
api=Api.inference,
adapter=AdapterSpec(

View file

@ -53,9 +53,16 @@ def available_providers() -> List[ProviderSpec]:
adapter_type="chromadb",
pip_packages=EMBEDDING_DEPS + ["chromadb-client"],
module="llama_stack.providers.remote.memory.chroma",
config_class="llama_stack.distribution.datatypes.RemoteProviderConfig",
config_class="llama_stack.providers.remote.memory.chroma.ChromaRemoteImplConfig",
),
),
InlineProviderSpec(
api=Api.memory,
provider_type="inline::chromadb",
pip_packages=EMBEDDING_DEPS + ["chromadb"],
module="llama_stack.providers.inline.memory.chroma",
config_class="llama_stack.providers.inline.memory.chroma.ChromaInlineImplConfig",
),
remote_provider_spec(
Api.memory,
AdapterSpec(

View file

@ -14,9 +14,13 @@ def available_providers() -> List[ProviderSpec]:
InlineProviderSpec(
api=Api.telemetry,
provider_type="inline::meta-reference",
pip_packages=[],
module="llama_stack.providers.inline.meta_reference.telemetry",
config_class="llama_stack.providers.inline.meta_reference.telemetry.ConsoleConfig",
pip_packages=[
"opentelemetry-sdk",
"opentelemetry-exporter-otlp-proto-http",
],
api_dependencies=[Api.datasetio],
module="llama_stack.providers.inline.telemetry.meta_reference",
config_class="llama_stack.providers.inline.telemetry.meta_reference.config.TelemetryConfig",
),
remote_provider_spec(
api=Api.telemetry,
@ -27,18 +31,4 @@ def available_providers() -> List[ProviderSpec]:
config_class="llama_stack.providers.remote.telemetry.sample.SampleConfig",
),
),
remote_provider_spec(
api=Api.telemetry,
adapter=AdapterSpec(
adapter_type="opentelemetry-jaeger",
pip_packages=[
"opentelemetry-api",
"opentelemetry-sdk",
"opentelemetry-exporter-jaeger",
"opentelemetry-semantic-conventions",
],
module="llama_stack.providers.remote.telemetry.opentelemetry",
config_class="llama_stack.providers.remote.telemetry.opentelemetry.OpenTelemetryConfig",
),
),
]

View file

@ -3,7 +3,7 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Optional
from typing import Any, Dict, List, Optional
from llama_stack.apis.datasetio import * # noqa: F403
@ -64,6 +64,11 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
)
self.dataset_infos[dataset_def.identifier] = dataset_def
async def unregister_dataset(self, dataset_id: str) -> None:
key = f"{DATASETS_PREFIX}{dataset_id}"
await self.kvstore.delete(key=key)
del self.dataset_infos[dataset_id]
async def get_rows_paginated(
self,
dataset_id: str,
@ -95,3 +100,22 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
total_count=len(rows),
next_page_token=str(end),
)
async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
dataset_def = self.dataset_infos[dataset_id]
loaded_dataset = load_hf_dataset(dataset_def)
# Convert rows to HF Dataset format
new_dataset = hf_datasets.Dataset.from_list(rows)
# Concatenate the new rows with existing dataset
updated_dataset = hf_datasets.concatenate_datasets(
[loaded_dataset, new_dataset]
)
if dataset_def.metadata.get("path", None):
updated_dataset.push_to_hub(dataset_def.metadata["path"])
else:
raise NotImplementedError(
"Uploading to URL-based datasets is not supported yet"
)

View file

@ -0,0 +1,21 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .config import CerebrasImplConfig
async def get_adapter_impl(config: CerebrasImplConfig, _deps):
from .cerebras import CerebrasInferenceAdapter
assert isinstance(
config, CerebrasImplConfig
), f"Unexpected config type: {type(config)}"
impl = CerebrasInferenceAdapter(config)
await impl.initialize()
return impl

View file

@ -0,0 +1,191 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import AsyncGenerator
from cerebras.cloud.sdk import AsyncCerebras
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import Message
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_stack.apis.inference import * # noqa: F403
from llama_models.datatypes import CoreModelId
from llama_stack.providers.utils.inference.model_registry import (
build_model_alias,
ModelRegistryHelper,
)
from llama_stack.providers.utils.inference.openai_compat import (
get_sampling_options,
process_chat_completion_response,
process_chat_completion_stream_response,
process_completion_response,
process_completion_stream_response,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_prompt,
completion_request_to_prompt,
)
from .config import CerebrasImplConfig
model_aliases = [
build_model_alias(
"llama3.1-8b",
CoreModelId.llama3_1_8b_instruct.value,
),
build_model_alias(
"llama3.1-70b",
CoreModelId.llama3_1_70b_instruct.value,
),
]
class CerebrasInferenceAdapter(ModelRegistryHelper, Inference):
def __init__(self, config: CerebrasImplConfig) -> None:
ModelRegistryHelper.__init__(
self,
model_aliases=model_aliases,
)
self.config = config
self.formatter = ChatFormat(Tokenizer.get_instance())
self.client = AsyncCerebras(
base_url=self.config.base_url, api_key=self.config.api_key
)
async def initialize(self) -> None:
return
async def shutdown(self) -> None:
pass
async def completion(
self,
model_id: str,
content: InterleavedTextMedia,
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
model = await self.model_store.get_model(model_id)
request = CompletionRequest(
model=model.provider_resource_id,
content=content,
sampling_params=sampling_params,
response_format=response_format,
stream=stream,
logprobs=logprobs,
)
if stream:
return self._stream_completion(
request,
)
else:
return await self._nonstream_completion(request)
async def _nonstream_completion(
self, request: CompletionRequest
) -> CompletionResponse:
params = self._get_params(request)
r = await self.client.completions.create(**params)
return process_completion_response(r, self.formatter)
async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
params = self._get_params(request)
stream = await self.client.completions.create(**params)
async for chunk in process_completion_stream_response(stream, self.formatter):
yield chunk
async def chat_completion(
self,
model_id: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
model = await self.model_store.get_model(model_id)
request = ChatCompletionRequest(
model=model.provider_resource_id,
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
tool_choice=tool_choice,
tool_prompt_format=tool_prompt_format,
response_format=response_format,
stream=stream,
logprobs=logprobs,
)
if stream:
return self._stream_chat_completion(request)
else:
return await self._nonstream_chat_completion(request)
async def _nonstream_chat_completion(
self, request: CompletionRequest
) -> CompletionResponse:
params = self._get_params(request)
r = await self.client.completions.create(**params)
return process_chat_completion_response(r, self.formatter)
async def _stream_chat_completion(
self, request: CompletionRequest
) -> AsyncGenerator:
params = self._get_params(request)
stream = await self.client.completions.create(**params)
async for chunk in process_chat_completion_stream_response(
stream, self.formatter
):
yield chunk
def _get_params(
self, request: Union[ChatCompletionRequest, CompletionRequest]
) -> dict:
if request.sampling_params and request.sampling_params.top_k:
raise ValueError("`top_k` not supported by Cerebras")
prompt = ""
if type(request) == ChatCompletionRequest:
prompt = chat_completion_request_to_prompt(
request, self.get_llama_model(request.model), self.formatter
)
elif type(request) == CompletionRequest:
prompt = completion_request_to_prompt(request, self.formatter)
else:
raise ValueError(f"Unknown request type {type(request)}")
return {
"model": request.model,
"prompt": prompt,
"stream": request.stream,
**get_sampling_options(request.sampling_params),
}
async def embeddings(
self,
model_id: str,
contents: List[InterleavedTextMedia],
) -> EmbeddingsResponse:
raise NotImplementedError()

View file

@ -0,0 +1,32 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
from typing import Any, Dict, Optional
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field
DEFAULT_BASE_URL = "https://api.cerebras.ai"
@json_schema_type
class CerebrasImplConfig(BaseModel):
base_url: str = Field(
default=os.environ.get("CEREBRAS_BASE_URL", DEFAULT_BASE_URL),
description="Base URL for the Cerebras API",
)
api_key: Optional[str] = Field(
default=os.environ.get("CEREBRAS_API_KEY"),
description="Cerebras API Key",
)
@classmethod
def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
return {
"base_url": DEFAULT_BASE_URL,
"api_key": "${env.CEREBRAS_API_KEY}",
}

View file

@ -9,6 +9,7 @@ from typing import AsyncIterator, List, Optional, Union
from llama_models.datatypes import SamplingParams
from llama_models.llama3.api.datatypes import (
ImageMedia,
InterleavedTextMedia,
Message,
ToolChoice,
@ -22,6 +23,7 @@ from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseStreamChunk,
CompletionRequest,
CompletionResponse,
CompletionResponseStreamChunk,
EmbeddingsResponse,
@ -37,8 +39,11 @@ from llama_stack.providers.utils.inference.model_registry import (
from . import NVIDIAConfig
from .openai_utils import (
convert_chat_completion_request,
convert_completion_request,
convert_openai_chat_completion_choice,
convert_openai_chat_completion_stream,
convert_openai_completion_choice,
convert_openai_completion_stream,
)
from .utils import _is_nvidia_hosted, check_health
@ -115,7 +120,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
timeout=self._config.timeout,
)
def completion(
async def completion(
self,
model_id: str,
content: InterleavedTextMedia,
@ -124,7 +129,38 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]:
raise NotImplementedError()
if isinstance(content, ImageMedia) or (
isinstance(content, list)
and any(isinstance(c, ImageMedia) for c in content)
):
raise NotImplementedError("ImageMedia is not supported")
await check_health(self._config) # this raises errors
request = convert_completion_request(
request=CompletionRequest(
model=self.get_provider_model_id(model_id),
content=content,
sampling_params=sampling_params,
response_format=response_format,
stream=stream,
logprobs=logprobs,
),
n=1,
)
try:
response = await self._client.completions.create(**request)
except APIConnectionError as e:
raise ConnectionError(
f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}"
) from e
if stream:
return convert_openai_completion_stream(response)
else:
# we pass n=1 to get only one completion
return convert_openai_completion_choice(response.choices[0])
async def embeddings(
self,

View file

@ -17,7 +17,6 @@ from llama_models.llama3.api.datatypes import (
ToolDefinition,
)
from openai import AsyncStream
from openai.types.chat import (
ChatCompletionAssistantMessageParam as OpenAIChatCompletionAssistantMessage,
ChatCompletionChunk as OpenAIChatCompletionChunk,
@ -31,10 +30,11 @@ from openai.types.chat.chat_completion import (
Choice as OpenAIChoice,
ChoiceLogprobs as OpenAIChoiceLogprobs, # same as chat_completion_chunk ChoiceLogprobs
)
from openai.types.chat.chat_completion_message_tool_call_param import (
Function as OpenAIFunction,
)
from openai.types.completion import Completion as OpenAICompletion
from openai.types.completion_choice import Logprobs as OpenAICompletionLogprobs
from llama_stack.apis.inference import (
ChatCompletionRequest,
@ -42,6 +42,9 @@ from llama_stack.apis.inference import (
ChatCompletionResponseEvent,
ChatCompletionResponseEventType,
ChatCompletionResponseStreamChunk,
CompletionRequest,
CompletionResponse,
CompletionResponseStreamChunk,
JsonSchemaResponseFormat,
Message,
SystemMessage,
@ -579,3 +582,165 @@ async def convert_openai_chat_completion_stream(
stop_reason=stop_reason,
)
)
def convert_completion_request(
request: CompletionRequest,
n: int = 1,
) -> dict:
"""
Convert a ChatCompletionRequest to an OpenAI API-compatible dictionary.
"""
# model -> model
# prompt -> prompt
# sampling_params TODO(mattf): review strategy
# strategy=greedy -> nvext.top_k = -1, temperature = temperature
# strategy=top_p -> nvext.top_k = -1, top_p = top_p
# strategy=top_k -> nvext.top_k = top_k
# temperature -> temperature
# top_p -> top_p
# top_k -> nvext.top_k
# max_tokens -> max_tokens
# repetition_penalty -> nvext.repetition_penalty
# response_format -> nvext.guided_json
# stream -> stream
# logprobs.top_k -> logprobs
nvext = {}
payload: Dict[str, Any] = dict(
model=request.model,
prompt=request.content,
stream=request.stream,
extra_body=dict(nvext=nvext),
extra_headers={
b"User-Agent": b"llama-stack: nvidia-inference-adapter",
},
n=n,
)
if request.response_format:
# this is not openai compliant, it is a nim extension
nvext.update(guided_json=request.response_format.json_schema)
if request.logprobs:
payload.update(logprobs=request.logprobs.top_k)
if request.sampling_params:
nvext.update(repetition_penalty=request.sampling_params.repetition_penalty)
if request.sampling_params.max_tokens:
payload.update(max_tokens=request.sampling_params.max_tokens)
if request.sampling_params.strategy == "top_p":
nvext.update(top_k=-1)
payload.update(top_p=request.sampling_params.top_p)
elif request.sampling_params.strategy == "top_k":
if (
request.sampling_params.top_k != -1
and request.sampling_params.top_k < 1
):
warnings.warn("top_k must be -1 or >= 1")
nvext.update(top_k=request.sampling_params.top_k)
elif request.sampling_params.strategy == "greedy":
nvext.update(top_k=-1)
payload.update(temperature=request.sampling_params.temperature)
return payload
def _convert_openai_completion_logprobs(
logprobs: Optional[OpenAICompletionLogprobs],
) -> Optional[List[TokenLogProbs]]:
"""
Convert an OpenAI CompletionLogprobs into a list of TokenLogProbs.
OpenAI CompletionLogprobs:
text_offset: Optional[List[int]]
token_logprobs: Optional[List[float]]
tokens: Optional[List[str]]
top_logprobs: Optional[List[Dict[str, float]]]
->
TokenLogProbs:
logprobs_by_token: Dict[str, float]
- token, logprob
"""
if not logprobs:
return None
return [
TokenLogProbs(logprobs_by_token=logprobs) for logprobs in logprobs.top_logprobs
]
def convert_openai_completion_choice(
choice: OpenAIChoice,
) -> CompletionResponse:
"""
Convert an OpenAI Completion Choice into a CompletionResponse.
OpenAI Completion Choice:
text: str
finish_reason: str
logprobs: Optional[ChoiceLogprobs]
->
CompletionResponse:
completion_message: CompletionMessage
logprobs: Optional[List[TokenLogProbs]]
CompletionMessage:
role: Literal["assistant"]
content: str | ImageMedia | List[str | ImageMedia]
stop_reason: StopReason
tool_calls: List[ToolCall]
class StopReason(Enum):
end_of_turn = "end_of_turn"
end_of_message = "end_of_message"
out_of_tokens = "out_of_tokens"
"""
return CompletionResponse(
content=choice.text,
stop_reason=_convert_openai_finish_reason(choice.finish_reason),
logprobs=_convert_openai_completion_logprobs(choice.logprobs),
)
async def convert_openai_completion_stream(
stream: AsyncStream[OpenAICompletion],
) -> AsyncGenerator[CompletionResponse, None]:
"""
Convert a stream of OpenAI Completions into a stream
of ChatCompletionResponseStreamChunks.
OpenAI Completion:
id: str
choices: List[OpenAICompletionChoice]
created: int
model: str
system_fingerprint: Optional[str]
usage: Optional[OpenAICompletionUsage]
OpenAI CompletionChoice:
finish_reason: str
index: int
logprobs: Optional[OpenAILogprobs]
text: str
->
CompletionResponseStreamChunk:
delta: str
stop_reason: Optional[StopReason]
logprobs: Optional[List[TokenLogProbs]]
"""
async for chunk in stream:
choice = chunk.choices[0]
yield CompletionResponseStreamChunk(
delta=choice.text,
stop_reason=_convert_openai_finish_reason(choice.finish_reason),
logprobs=_convert_openai_completion_logprobs(choice.logprobs),
)

View file

@ -180,7 +180,6 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
async def _nonstream_completion(self, request: CompletionRequest) -> AsyncGenerator:
params = await self._get_params(request)
r = await self.client.generate(**params)
assert isinstance(r, dict)
choice = OpenAICompatCompletionChoice(
finish_reason=r["done_reason"] if r["done"] else None,
@ -270,7 +269,6 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
r = await self.client.chat(**params)
else:
r = await self.client.generate(**params)
assert isinstance(r, dict)
if "message" in r:
choice = OpenAICompatCompletionChoice(

View file

@ -100,6 +100,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
tool_prompt_format=tool_prompt_format,
stream=stream,
logprobs=logprobs,
response_format=response_format,
)
if stream:
return self._stream_chat_completion(request, self.client)
@ -180,6 +181,16 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
self.formatter,
)
if fmt := request.response_format:
if fmt.type == ResponseFormatType.json_schema.value:
input_dict["extra_body"] = {
"guided_json": request.response_format.json_schema
}
elif fmt.type == ResponseFormatType.grammar.value:
raise NotImplementedError("Grammar response format not supported yet")
else:
raise ValueError(f"Unknown response format {fmt.type}")
return {
"model": request.model,
**input_dict,

View file

@ -4,12 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.distribution.datatypes import RemoteProviderConfig
from .config import ChromaRemoteImplConfig
async def get_adapter_impl(config: RemoteProviderConfig, _deps):
async def get_adapter_impl(config: ChromaRemoteImplConfig, _deps):
from .chroma import ChromaMemoryAdapter
impl = ChromaMemoryAdapter(config.url)
impl = ChromaMemoryAdapter(config)
await impl.initialize()
return impl

View file

@ -3,7 +3,7 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import json
import logging
from typing import List
@ -12,21 +12,31 @@ from urllib.parse import urlparse
import chromadb
from numpy.typing import NDArray
from pydantic import parse_obj_as
from llama_stack.apis.memory import * # noqa: F403
from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate
from llama_stack.providers.inline.memory.chroma import ChromaInlineImplConfig
from llama_stack.providers.utils.memory.vector_store import (
BankWithIndex,
EmbeddingIndex,
)
from .config import ChromaRemoteImplConfig
log = logging.getLogger(__name__)
ChromaClientType = Union[chromadb.AsyncHttpClient, chromadb.PersistentClient]
# this is a helper to allow us to use async and non-async chroma clients interchangeably
async def maybe_await(result):
if asyncio.iscoroutine(result):
return await result
return result
class ChromaIndex(EmbeddingIndex):
def __init__(self, client: chromadb.AsyncHttpClient, collection):
def __init__(self, client: ChromaClientType, collection):
self.client = client
self.collection = collection
@ -35,19 +45,23 @@ class ChromaIndex(EmbeddingIndex):
embeddings
), f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
await self.collection.add(
documents=[chunk.json() for chunk in chunks],
embeddings=embeddings,
ids=[f"{c.document_id}:chunk-{i}" for i, c in enumerate(chunks)],
await maybe_await(
self.collection.add(
documents=[chunk.model_dump_json() for chunk in chunks],
embeddings=embeddings,
ids=[f"{c.document_id}:chunk-{i}" for i, c in enumerate(chunks)],
)
)
async def query(
self, embedding: NDArray, k: int, score_threshold: float
) -> QueryDocumentsResponse:
results = await self.collection.query(
query_embeddings=[embedding.tolist()],
n_results=k,
include=["documents", "distances"],
results = await maybe_await(
self.collection.query(
query_embeddings=[embedding.tolist()],
n_results=k,
include=["documents", "distances"],
)
)
distances = results["distances"][0]
documents = results["documents"][0]
@ -68,31 +82,33 @@ class ChromaIndex(EmbeddingIndex):
return QueryDocumentsResponse(chunks=chunks, scores=scores)
async def delete(self):
await self.client.delete_collection(self.collection.name)
await maybe_await(self.client.delete_collection(self.collection.name))
class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
def __init__(self, url: str) -> None:
log.info(f"Initializing ChromaMemoryAdapter with url: {url}")
url = url.rstrip("/")
parsed = urlparse(url)
if parsed.path and parsed.path != "/":
raise ValueError("URL should not contain a path")
self.host = parsed.hostname
self.port = parsed.port
def __init__(
self, config: Union[ChromaRemoteImplConfig, ChromaInlineImplConfig]
) -> None:
log.info(f"Initializing ChromaMemoryAdapter with url: {config}")
self.config = config
self.client = None
self.cache = {}
async def initialize(self) -> None:
try:
log.info(f"Connecting to Chroma server at: {self.host}:{self.port}")
self.client = await chromadb.AsyncHttpClient(host=self.host, port=self.port)
except Exception as e:
log.exception("Could not connect to Chroma server")
raise RuntimeError("Could not connect to Chroma server") from e
if isinstance(self.config, ChromaRemoteImplConfig):
log.info(f"Connecting to Chroma server at: {self.config.url}")
url = self.config.url.rstrip("/")
parsed = urlparse(url)
if parsed.path and parsed.path != "/":
raise ValueError("URL should not contain a path")
self.client = await chromadb.AsyncHttpClient(
host=parsed.hostname, port=parsed.port
)
else:
log.info(f"Connecting to Chroma local db at: {self.config.db_path}")
self.client = chromadb.PersistentClient(path=self.config.db_path)
async def shutdown(self) -> None:
pass
@ -105,33 +121,17 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
memory_bank.memory_bank_type == MemoryBankType.vector.value
), f"Only vector banks are supported {memory_bank.memory_bank_type}"
collection = await self.client.get_or_create_collection(
name=memory_bank.identifier,
metadata={"bank": memory_bank.model_dump_json()},
collection = await maybe_await(
self.client.get_or_create_collection(
name=memory_bank.identifier,
metadata={"bank": memory_bank.model_dump_json()},
)
)
bank_index = BankWithIndex(
bank=memory_bank, index=ChromaIndex(self.client, collection)
)
self.cache[memory_bank.identifier] = bank_index
async def list_memory_banks(self) -> List[MemoryBank]:
collections = await self.client.list_collections()
for collection in collections:
try:
data = json.loads(collection.metadata["bank"])
bank = parse_obj_as(VectorMemoryBank, data)
except Exception:
log.exception(f"Failed to parse bank: {collection.metadata}")
continue
index = BankWithIndex(
bank=bank,
index=ChromaIndex(self.client, collection),
)
self.cache[bank.identifier] = index
return [i.bank for i in self.cache.values()]
async def unregister_memory_bank(self, memory_bank_id: str) -> None:
await self.cache[memory_bank_id].index.delete()
del self.cache[memory_bank_id]
@ -163,7 +163,7 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
bank = await self.memory_bank_store.get_memory_bank(bank_id)
if not bank:
raise ValueError(f"Bank {bank_id} not found in Llama Stack")
collection = await self.client.get_collection(bank_id)
collection = await maybe_await(self.client.get_collection(bank_id))
if not collection:
raise ValueError(f"Bank {bank_id} not found in Chroma")
index = BankWithIndex(bank=bank, index=ChromaIndex(self.client, collection))

View file

@ -0,0 +1,17 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from pydantic import BaseModel
class ChromaRemoteImplConfig(BaseModel):
url: str
@classmethod
def sample_config(cls) -> Dict[str, Any]:
return {"url": "{env.CHROMADB_URL}"}

View file

@ -185,17 +185,6 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
await self.cache[memory_bank_id].index.delete()
del self.cache[memory_bank_id]
async def list_memory_banks(self) -> List[MemoryBank]:
banks = load_models(self.cursor, VectorMemoryBank)
for bank in banks:
if bank.identifier not in self.cache:
index = BankWithIndex(
bank=bank,
index=PGVectorIndex(bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor),
)
self.cache[bank.identifier] = index
return banks
async def insert_documents(
self,
bank_id: str,

View file

@ -127,11 +127,6 @@ class QdrantVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
self.cache[memory_bank.identifier] = index
async def list_memory_banks(self) -> List[MemoryBank]:
# Qdrant doesn't have collection level metadata to store the bank properties
# So we only return from the cache value
return [i.bank for i in self.cache.values()]
async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]:
if bank_id in self.cache:
return self.cache[bank_id]

View file

@ -14,7 +14,7 @@ class SampleMemoryImpl(Memory):
def __init__(self, config: SampleConfig):
self.config = config
async def register_memory_bank(self, memory_bank: MemoryBankDef) -> None:
async def register_memory_bank(self, memory_bank: MemoryBank) -> None:
# these are the memory banks the Llama Stack will use to route requests to this provider
# perform validation here if necessary
pass

View file

@ -141,13 +141,6 @@ class WeaviateMemoryAdapter(
)
self.cache[memory_bank.identifier] = index
async def list_memory_banks(self) -> List[MemoryBank]:
# TODO: right now the Llama Stack is the source of truth for these banks. That is
# not ideal. It should be Weaviate which is the source of truth. Unfortunately,
# list() happens at Stack startup when the Weaviate client (credentials) is not
# yet available. We need to figure out a way to make this work.
return [i.bank for i in self.cache.values()]
async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]:
if bank_id in self.cache:
return self.cache[bank_id]

View file

@ -1,15 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .config import OpenTelemetryConfig
async def get_adapter_impl(config: OpenTelemetryConfig, _deps):
from .opentelemetry import OpenTelemetryAdapter
impl = OpenTelemetryAdapter(config)
await impl.initialize()
return impl

View file

@ -1,27 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from pydantic import BaseModel, Field
class OpenTelemetryConfig(BaseModel):
otel_endpoint: str = Field(
default="http://localhost:4318/v1/traces",
description="The OpenTelemetry collector endpoint URL",
)
service_name: str = Field(
default="llama-stack",
description="The service name to use for telemetry",
)
@classmethod
def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
return {
"otel_endpoint": "${env.OTEL_ENDPOINT:http://localhost:4318/v1/traces}",
"service_name": "${env.OTEL_SERVICE_NAME:llama-stack}",
}

View file

@ -81,6 +81,18 @@ class TestDatasetIO:
assert len(response) == 1
assert response[0].identifier == "test_dataset"
with pytest.raises(Exception) as exc_info:
# unregister a dataset that does not exist
await datasets_impl.unregister_dataset("test_dataset2")
await datasets_impl.unregister_dataset("test_dataset")
response = await datasets_impl.list_datasets()
assert isinstance(response, list)
assert len(response) == 0
with pytest.raises(Exception) as exc_info:
await datasets_impl.unregister_dataset("test_dataset")
@pytest.mark.asyncio
async def test_get_rows_paginated(self, datasetio_stack):
datasetio_impl, datasets_impl = datasetio_stack

View file

@ -80,6 +80,13 @@ def pytest_addoption(parser):
help="Specify the inference model to use for testing",
)
parser.addoption(
"--judge-model",
action="store",
default="meta-llama/Llama-3.1-8B-Instruct",
help="Specify the judge model to use for testing",
)
def pytest_generate_tests(metafunc):
if "eval_stack" in metafunc.fixturenames:

View file

@ -7,7 +7,7 @@
import pytest
import pytest_asyncio
from llama_stack.distribution.datatypes import Api, Provider
from llama_stack.distribution.datatypes import Api, ModelInput, Provider
from llama_stack.providers.tests.resolver import construct_stack_for_test
from ..conftest import ProviderFixture, remote_stack_fixture
@ -35,7 +35,7 @@ EVAL_FIXTURES = ["meta_reference", "remote"]
@pytest_asyncio.fixture(scope="session")
async def eval_stack(request):
async def eval_stack(request, inference_model, judge_model):
fixture_dict = request.param
providers = {}
@ -66,6 +66,13 @@ async def eval_stack(request):
],
providers,
provider_data,
models=[
ModelInput(model_id=model)
for model in [
inference_model,
judge_model,
]
],
)
return test_stack.impls

View file

@ -38,7 +38,7 @@ class Testeval:
assert isinstance(response, list)
@pytest.mark.asyncio
async def test_eval_evaluate_rows(self, eval_stack):
async def test_eval_evaluate_rows(self, eval_stack, inference_model, judge_model):
eval_impl, eval_tasks_impl, datasetio_impl, datasets_impl, models_impl = (
eval_stack[Api.eval],
eval_stack[Api.eval_tasks],
@ -46,11 +46,7 @@ class Testeval:
eval_stack[Api.datasets],
eval_stack[Api.models],
)
for model_id in ["Llama3.2-3B-Instruct", "Llama3.1-8B-Instruct"]:
await models_impl.register_model(
model_id=model_id,
provider_id="",
)
await register_dataset(
datasets_impl, for_generation=True, dataset_id="test_dataset_for_eval"
)
@ -77,12 +73,12 @@ class Testeval:
scoring_functions=scoring_functions,
task_config=AppEvalTaskConfig(
eval_candidate=ModelCandidate(
model="Llama3.2-3B-Instruct",
model=inference_model,
sampling_params=SamplingParams(),
),
scoring_params={
"meta-reference::llm_as_judge_base": LLMAsJudgeScoringFnParams(
judge_model="Llama3.1-8B-Instruct",
judge_model=judge_model,
prompt_template=JUDGE_PROMPT,
judge_score_regexes=[
r"Total rating: (\d+)",
@ -97,18 +93,14 @@ class Testeval:
assert "basic::equality" in response.scores
@pytest.mark.asyncio
async def test_eval_run_eval(self, eval_stack):
async def test_eval_run_eval(self, eval_stack, inference_model, judge_model):
eval_impl, eval_tasks_impl, datasets_impl, models_impl = (
eval_stack[Api.eval],
eval_stack[Api.eval_tasks],
eval_stack[Api.datasets],
eval_stack[Api.models],
)
for model_id in ["Llama3.2-3B-Instruct", "Llama3.1-8B-Instruct"]:
await models_impl.register_model(
model_id=model_id,
provider_id="",
)
await register_dataset(
datasets_impl, for_generation=True, dataset_id="test_dataset_for_eval"
)
@ -127,7 +119,7 @@ class Testeval:
task_id=task_id,
task_config=AppEvalTaskConfig(
eval_candidate=ModelCandidate(
model="Llama3.2-3B-Instruct",
model=inference_model,
sampling_params=SamplingParams(),
),
),
@ -142,18 +134,14 @@ class Testeval:
assert "basic::subset_of" in eval_response.scores
@pytest.mark.asyncio
async def test_eval_run_benchmark_eval(self, eval_stack):
async def test_eval_run_benchmark_eval(self, eval_stack, inference_model):
eval_impl, eval_tasks_impl, datasets_impl, models_impl = (
eval_stack[Api.eval],
eval_stack[Api.eval_tasks],
eval_stack[Api.datasets],
eval_stack[Api.models],
)
for model_id in ["Llama3.2-3B-Instruct", "Llama3.1-8B-Instruct"]:
await models_impl.register_model(
model_id=model_id,
provider_id="",
)
response = await datasets_impl.list_datasets()
assert len(response) > 0
if response[0].provider_id != "huggingface":
@ -192,7 +180,7 @@ class Testeval:
task_id=benchmark_id,
task_config=BenchmarkEvalTaskConfig(
eval_candidate=ModelCandidate(
model="Llama3.2-3B-Instruct",
model=inference_model,
sampling_params=SamplingParams(),
),
num_examples=3,

View file

@ -17,6 +17,7 @@ from llama_stack.providers.inline.inference.meta_reference import (
)
from llama_stack.providers.remote.inference.bedrock import BedrockConfig
from llama_stack.providers.remote.inference.cerebras import CerebrasImplConfig
from llama_stack.providers.remote.inference.fireworks import FireworksImplConfig
from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig
from llama_stack.providers.remote.inference.ollama import OllamaImplConfig
@ -64,6 +65,21 @@ def inference_meta_reference(inference_model) -> ProviderFixture:
)
@pytest.fixture(scope="session")
def inference_cerebras() -> ProviderFixture:
return ProviderFixture(
providers=[
Provider(
provider_id="cerebras",
provider_type="remote::cerebras",
config=CerebrasImplConfig(
api_key=get_env_or_fail("CEREBRAS_API_KEY"),
).model_dump(),
)
],
)
@pytest.fixture(scope="session")
def inference_ollama(inference_model) -> ProviderFixture:
inference_model = (
@ -206,6 +222,7 @@ INFERENCE_FIXTURES = [
"vllm_remote",
"remote",
"bedrock",
"cerebras",
"nvidia",
"tgi",
]

View file

@ -94,6 +94,8 @@ class TestInference:
"remote::tgi",
"remote::together",
"remote::fireworks",
"remote::nvidia",
"remote::cerebras",
):
pytest.skip("Other inference providers don't support completion() yet")
@ -128,9 +130,7 @@ class TestInference:
@pytest.mark.asyncio
@pytest.mark.skip("This test is not quite robust")
async def test_completions_structured_output(
self, inference_model, inference_stack
):
async def test_completion_structured_output(self, inference_model, inference_stack):
inference_impl, _ = inference_stack
provider = inference_impl.routing_table.get_provider_impl(inference_model)
@ -139,6 +139,9 @@ class TestInference:
"remote::tgi",
"remote::together",
"remote::fireworks",
"remote::nvidia",
"remote::vllm",
"remote::cerebras",
):
pytest.skip(
"Other inference providers don't support structured output in completions yet"
@ -198,6 +201,7 @@ class TestInference:
"remote::fireworks",
"remote::tgi",
"remote::together",
"remote::vllm",
"remote::nvidia",
):
pytest.skip("Other inference providers don't support structured output yet")
@ -211,7 +215,15 @@ class TestInference:
response = await inference_impl.chat_completion(
model_id=inference_model,
messages=[
SystemMessage(content="You are a helpful assistant."),
# we include context about Michael Jordan in the prompt so that the test is
# focused on the funtionality of the model and not on the information embedded
# in the model. Llama 3.2 3B Instruct tends to think MJ played for 14 seasons.
SystemMessage(
content=(
"You are a helpful assistant.\n\n"
"Michael Jordan was born in 1963. He played basketball for the Chicago Bulls for 15 seasons."
)
),
UserMessage(content="Please give me information about Michael Jordan."),
],
stream=False,

View file

@ -10,8 +10,10 @@ import tempfile
import pytest
import pytest_asyncio
from llama_stack.distribution.datatypes import Api, Provider, RemoteProviderConfig
from llama_stack.distribution.datatypes import Api, Provider
from llama_stack.providers.inline.memory.chroma import ChromaInlineImplConfig
from llama_stack.providers.inline.memory.faiss import FaissImplConfig
from llama_stack.providers.remote.memory.chroma import ChromaRemoteImplConfig
from llama_stack.providers.remote.memory.pgvector import PGVectorConfig
from llama_stack.providers.remote.memory.weaviate import WeaviateConfig
from llama_stack.providers.tests.resolver import construct_stack_for_test
@ -79,15 +81,21 @@ def memory_weaviate() -> ProviderFixture:
@pytest.fixture(scope="session")
def memory_chroma() -> ProviderFixture:
url = os.getenv("CHROMA_URL")
if url:
config = ChromaRemoteImplConfig(url=url)
provider_type = "remote::chromadb"
else:
if not os.getenv("CHROMA_DB_PATH"):
raise ValueError("CHROMA_DB_PATH or CHROMA_URL must be set")
config = ChromaInlineImplConfig(db_path=os.getenv("CHROMA_DB_PATH"))
provider_type = "inline::chromadb"
return ProviderFixture(
providers=[
Provider(
provider_id="chroma",
provider_type="remote::chromadb",
config=RemoteProviderConfig(
host=get_env_or_fail("CHROMA_HOST"),
port=get_env_or_fail("CHROMA_PORT"),
).model_dump(),
provider_type=provider_type,
config=config.model_dump(),
)
]
)

Binary file not shown.

View file

@ -0,0 +1,76 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import base64
import mimetypes
import os
from pathlib import Path
import pytest
from llama_stack.apis.memory.memory import MemoryBankDocument, URL
from llama_stack.providers.utils.memory.vector_store import content_from_doc
DUMMY_PDF_PATH = Path(os.path.abspath(__file__)).parent / "fixtures" / "dummy.pdf"
def read_file(file_path: str) -> bytes:
with open(file_path, "rb") as file:
return file.read()
def data_url_from_file(file_path: str) -> str:
with open(file_path, "rb") as file:
file_content = file.read()
base64_content = base64.b64encode(file_content).decode("utf-8")
mime_type, _ = mimetypes.guess_type(file_path)
data_url = f"data:{mime_type};base64,{base64_content}"
return data_url
class TestVectorStore:
@pytest.mark.asyncio
async def test_returns_content_from_pdf_data_uri(self):
data_uri = data_url_from_file(DUMMY_PDF_PATH)
doc = MemoryBankDocument(
document_id="dummy",
content=data_uri,
mime_type="application/pdf",
metadata={},
)
content = await content_from_doc(doc)
assert content == "Dummy PDF file"
@pytest.mark.asyncio
async def test_downloads_pdf_and_returns_content(self):
# Using GitHub to host the PDF file
url = "https://raw.githubusercontent.com/meta-llama/llama-stack/da035d69cfca915318eaf485770a467ca3c2a238/llama_stack/providers/tests/memory/fixtures/dummy.pdf"
doc = MemoryBankDocument(
document_id="dummy",
content=url,
mime_type="application/pdf",
metadata={},
)
content = await content_from_doc(doc)
assert content == "Dummy PDF file"
@pytest.mark.asyncio
async def test_downloads_pdf_and_returns_content_with_url_object(self):
# Using GitHub to host the PDF file
url = "https://raw.githubusercontent.com/meta-llama/llama-stack/da035d69cfca915318eaf485770a467ca3c2a238/llama_stack/providers/tests/memory/fixtures/dummy.pdf"
doc = MemoryBankDocument(
document_id="dummy",
content=URL(
uri=url,
),
mime_type="application/pdf",
metadata={},
)
content = await content_from_doc(doc)
assert content == "Dummy PDF file"

View file

@ -47,6 +47,7 @@ def pytest_configure(config):
for fixture_name in [
"basic_scoring_together_inference",
"braintrust_scoring_together_inference",
"llm_as_judge_scoring_together_inference",
]:
config.addinivalue_line(
"markers",
@ -61,9 +62,23 @@ def pytest_addoption(parser):
default="meta-llama/Llama-3.2-3B-Instruct",
help="Specify the inference model to use for testing",
)
parser.addoption(
"--judge-model",
action="store",
default="meta-llama/Llama-3.1-8B-Instruct",
help="Specify the judge model to use for testing",
)
def pytest_generate_tests(metafunc):
judge_model = metafunc.config.getoption("--judge-model")
if "judge_model" in metafunc.fixturenames:
metafunc.parametrize(
"judge_model",
[pytest.param(judge_model, id="")],
indirect=True,
)
if "scoring_stack" in metafunc.fixturenames:
available_fixtures = {
"scoring": SCORING_FIXTURES,

View file

@ -21,6 +21,13 @@ def scoring_remote() -> ProviderFixture:
return remote_stack_fixture()
@pytest.fixture(scope="session")
def judge_model(request):
if hasattr(request, "param"):
return request.param
return request.config.getoption("--judge-model", None)
@pytest.fixture(scope="session")
def scoring_basic() -> ProviderFixture:
return ProviderFixture(
@ -66,7 +73,7 @@ SCORING_FIXTURES = ["basic", "remote", "braintrust", "llm_as_judge"]
@pytest_asyncio.fixture(scope="session")
async def scoring_stack(request, inference_model):
async def scoring_stack(request, inference_model, judge_model):
fixture_dict = request.param
providers = {}
@ -85,8 +92,7 @@ async def scoring_stack(request, inference_model):
ModelInput(model_id=model)
for model in [
inference_model,
"Llama3.1-405B-Instruct",
"Llama3.1-8B-Instruct",
judge_model,
]
],
)

View file

@ -7,7 +7,12 @@
import pytest
from llama_stack.apis.scoring_functions import * # noqa: F403
from llama_stack.apis.scoring_functions import (
AggregationFunctionType,
BasicScoringFnParams,
LLMAsJudgeScoringFnParams,
RegexParserScoringFnParams,
)
from llama_stack.distribution.datatypes import Api
from llama_stack.providers.tests.datasetio.test_datasetio import register_dataset
@ -18,6 +23,11 @@ from llama_stack.providers.tests.datasetio.test_datasetio import register_datase
# -v -s --tb=short --disable-warnings
@pytest.fixture
def sample_judge_prompt_template():
return "Output a number response in the following format: Score: <number>, where <number> is the number between 0 and 9."
class TestScoring:
@pytest.mark.asyncio
async def test_scoring_functions_list(self, scoring_stack):
@ -54,12 +64,6 @@ class TestScoring:
response = await datasets_impl.list_datasets()
assert len(response) == 1
for model_id in ["Llama3.2-3B-Instruct", "Llama3.1-8B-Instruct"]:
await models_impl.register_model(
model_id=model_id,
provider_id="",
)
# scoring individual rows
rows = await datasetio_impl.get_rows_paginated(
dataset_id="test_dataset",
@ -92,7 +96,9 @@ class TestScoring:
assert len(response.results[x].score_rows) == 5
@pytest.mark.asyncio
async def test_scoring_score_with_params(self, scoring_stack):
async def test_scoring_score_with_params_llm_as_judge(
self, scoring_stack, sample_judge_prompt_template, judge_model
):
(
scoring_impl,
scoring_functions_impl,
@ -110,12 +116,6 @@ class TestScoring:
response = await datasets_impl.list_datasets()
assert len(response) == 1
for model_id in ["Llama3.1-405B-Instruct"]:
await models_impl.register_model(
model_id=model_id,
provider_id="",
)
scoring_fns_list = await scoring_functions_impl.list_scoring_functions()
provider_id = scoring_fns_list[0].provider_id
if provider_id == "braintrust" or provider_id == "basic":
@ -129,10 +129,11 @@ class TestScoring:
assert len(rows.rows) == 3
scoring_functions = {
"llm-as-judge::llm_as_judge_base": LLMAsJudgeScoringFnParams(
judge_model="Llama3.1-405B-Instruct",
prompt_template="Output a number response in the following format: Score: <number>, where <number> is the number between 0 and 9.",
"llm-as-judge::base": LLMAsJudgeScoringFnParams(
judge_model=judge_model,
prompt_template=sample_judge_prompt_template,
judge_score_regexes=[r"Score: (\d+)"],
aggregation_functions=[AggregationFunctionType.categorical_count],
)
}
@ -154,3 +155,67 @@ class TestScoring:
for x in scoring_functions:
assert x in response.results
assert len(response.results[x].score_rows) == 5
@pytest.mark.asyncio
async def test_scoring_score_with_aggregation_functions(
self, scoring_stack, sample_judge_prompt_template, judge_model
):
(
scoring_impl,
scoring_functions_impl,
datasetio_impl,
datasets_impl,
models_impl,
) = (
scoring_stack[Api.scoring],
scoring_stack[Api.scoring_functions],
scoring_stack[Api.datasetio],
scoring_stack[Api.datasets],
scoring_stack[Api.models],
)
await register_dataset(datasets_impl)
rows = await datasetio_impl.get_rows_paginated(
dataset_id="test_dataset",
rows_in_page=3,
)
assert len(rows.rows) == 3
scoring_fns_list = await scoring_functions_impl.list_scoring_functions()
scoring_functions = {}
aggr_fns = [
AggregationFunctionType.accuracy,
AggregationFunctionType.median,
AggregationFunctionType.categorical_count,
AggregationFunctionType.average,
]
for x in scoring_fns_list:
if x.provider_id == "llm-as-judge":
aggr_fns = [AggregationFunctionType.categorical_count]
scoring_functions[x.identifier] = LLMAsJudgeScoringFnParams(
judge_model=judge_model,
prompt_template=sample_judge_prompt_template,
judge_score_regexes=[r"Score: (\d+)"],
aggregation_functions=aggr_fns,
)
elif x.provider_id == "basic":
if "regex_parser" in x.identifier:
scoring_functions[x.identifier] = RegexParserScoringFnParams(
aggregation_functions=aggr_fns,
)
else:
scoring_functions[x.identifier] = BasicScoringFnParams(
aggregation_functions=aggr_fns,
)
else:
scoring_functions[x.identifier] = None
response = await scoring_impl.score(
input_rows=rows.rows,
scoring_functions=scoring_functions,
)
assert len(response.results) == len(scoring_functions)
for x in scoring_functions:
assert x in response.results
assert len(response.results[x].score_rows) == len(rows.rows)
assert len(response.results[x].aggregated_results) == len(aggr_fns)

View file

@ -27,7 +27,8 @@ def supported_inference_models() -> List[Model]:
m
for m in all_registered_models()
if (
m.model_family in {ModelFamily.llama3_1, ModelFamily.llama3_2}
m.model_family
in {ModelFamily.llama3_1, ModelFamily.llama3_2, ModelFamily.llama3_3}
or is_supported_safety_model(m)
)
]

View file

@ -45,6 +45,13 @@ def get_embedding_model(model: str) -> "SentenceTransformer":
return loaded_model
def parse_pdf(data: bytes) -> str:
# For PDF and DOC/DOCX files, we can't reliably convert to string
pdf_bytes = io.BytesIO(data)
pdf_reader = PdfReader(pdf_bytes)
return "\n".join([page.extract_text() for page in pdf_reader.pages])
def parse_data_url(data_url: str):
data_url_pattern = re.compile(
r"^"
@ -88,10 +95,7 @@ def content_from_data(data_url: str) -> str:
return data.decode(encoding)
elif mime_type == "application/pdf":
# For PDF and DOC/DOCX files, we can't reliably convert to string)
pdf_bytes = io.BytesIO(data)
pdf_reader = PdfReader(pdf_bytes)
return "\n".join([page.extract_text() for page in pdf_reader.pages])
return parse_pdf(data)
else:
log.error("Could not extract content from data_url properly.")
@ -105,6 +109,9 @@ async def content_from_doc(doc: MemoryBankDocument) -> str:
else:
async with httpx.AsyncClient() as client:
r = await client.get(doc.content.uri)
if doc.mime_type == "application/pdf":
return parse_pdf(r.content)
else:
return r.text
pattern = re.compile("^(https?://|file://|data:)")
@ -114,6 +121,9 @@ async def content_from_doc(doc: MemoryBankDocument) -> str:
else:
async with httpx.AsyncClient() as client:
r = await client.get(doc.content)
if doc.mime_type == "application/pdf":
return parse_pdf(r.content)
else:
return r.text
return interleaved_text_media_as_str(doc.content)

View file

@ -3,9 +3,10 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import statistics
from typing import Any, Dict, List
from llama_stack.apis.scoring import ScoringResultRow
from llama_stack.apis.scoring import AggregationFunctionType, ScoringResultRow
def aggregate_accuracy(scoring_results: List[ScoringResultRow]) -> Dict[str, Any]:
@ -26,3 +27,38 @@ def aggregate_average(scoring_results: List[ScoringResultRow]) -> Dict[str, Any]
)
/ len([_ for _ in scoring_results if _["score"] is not None]),
}
def aggregate_categorical_count(
scoring_results: List[ScoringResultRow],
) -> Dict[str, Any]:
scores = [str(r["score"]) for r in scoring_results]
unique_scores = sorted(list(set(scores)))
return {"categorical_count": {s: scores.count(s) for s in unique_scores}}
def aggregate_median(scoring_results: List[ScoringResultRow]) -> Dict[str, Any]:
scores = [r["score"] for r in scoring_results if r["score"] is not None]
median = statistics.median(scores) if scores else None
return {"median": median}
# TODO: decide whether we want to make aggregation functions as a registerable resource
AGGREGATION_FUNCTIONS = {
AggregationFunctionType.accuracy: aggregate_accuracy,
AggregationFunctionType.average: aggregate_average,
AggregationFunctionType.categorical_count: aggregate_categorical_count,
AggregationFunctionType.median: aggregate_median,
}
def aggregate_metrics(
scoring_results: List[ScoringResultRow], metrics: List[AggregationFunctionType]
) -> Dict[str, Any]:
agg_results = {}
for metric in metrics:
if metric not in AGGREGATION_FUNCTIONS:
raise ValueError(f"Aggregation function {metric} not found")
agg_fn = AGGREGATION_FUNCTIONS[metric]
agg_results[metric] = agg_fn(scoring_results)
return agg_results

View file

@ -8,11 +8,12 @@ from typing import Any, Dict, List, Optional
from llama_stack.apis.scoring import ScoringFnParams, ScoringResultRow
from llama_stack.apis.scoring_functions import ScoringFn
from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_metrics
class BaseScoringFn(ABC):
"""
Base interface class for all meta-reference scoring_fns.
Base interface class for all native scoring_fns.
Each scoring_fn needs to implement the following methods:
- score_row(self, row)
- aggregate(self, scoring_fn_results)
@ -44,11 +45,27 @@ class BaseScoringFn(ABC):
) -> ScoringResultRow:
raise NotImplementedError()
@abstractmethod
async def aggregate(
self, scoring_results: List[ScoringResultRow]
self,
scoring_results: List[ScoringResultRow],
scoring_fn_identifier: Optional[str] = None,
scoring_params: Optional[ScoringFnParams] = None,
) -> Dict[str, Any]:
raise NotImplementedError()
params = self.supported_fn_defs_registry[scoring_fn_identifier].params
if scoring_params is not None:
if params is None:
params = scoring_params
else:
params.aggregation_functions = scoring_params.aggregation_functions
aggregation_functions = []
if (
params
and hasattr(params, "aggregation_functions")
and params.aggregation_functions
):
aggregation_functions.extend(params.aggregation_functions)
return aggregate_metrics(scoring_results, aggregation_functions)
async def score(
self,

View file

@ -0,0 +1,87 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List, Optional
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.telemetry import QueryCondition, Span, SpanWithChildren
class TelemetryDatasetMixin:
"""Mixin class that provides dataset-related functionality for telemetry providers."""
datasetio_api: DatasetIO
async def save_spans_to_dataset(
self,
attribute_filters: List[QueryCondition],
attributes_to_save: List[str],
dataset_id: str,
max_depth: Optional[int] = None,
) -> None:
spans = await self.query_spans(
attribute_filters=attribute_filters,
attributes_to_return=attributes_to_save,
max_depth=max_depth,
)
rows = [
{
"trace_id": span.trace_id,
"span_id": span.span_id,
"parent_span_id": span.parent_span_id,
"name": span.name,
"start_time": span.start_time,
"end_time": span.end_time,
**{attr: span.attributes.get(attr) for attr in attributes_to_save},
}
for span in spans
]
await self.datasetio_api.append_rows(dataset_id=dataset_id, rows=rows)
async def query_spans(
self,
attribute_filters: List[QueryCondition],
attributes_to_return: List[str],
max_depth: Optional[int] = None,
) -> List[Span]:
traces = await self.query_traces(attribute_filters=attribute_filters)
spans = []
for trace in traces:
span_tree = await self.get_span_tree(
span_id=trace.root_span_id,
attributes_to_return=attributes_to_return,
max_depth=max_depth,
)
def extract_spans(span: SpanWithChildren) -> List[Span]:
result = []
if span.attributes and all(
attr in span.attributes and span.attributes[attr] is not None
for attr in attributes_to_return
):
result.append(
Span(
trace_id=trace.root_span_id,
span_id=span.span_id,
parent_span_id=span.parent_span_id,
name=span.name,
start_time=span.start_time,
end_time=span.end_time,
attributes=span.attributes,
)
)
for child in span.children:
result.extend(extract_spans(child))
return result
spans.extend(extract_spans(span_tree))
return spans

View file

@ -0,0 +1,178 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
from datetime import datetime
from typing import List, Optional, Protocol
import aiosqlite
from llama_stack.apis.telemetry import QueryCondition, SpanWithChildren, Trace
class TraceStore(Protocol):
async def query_traces(
self,
attribute_filters: Optional[List[QueryCondition]] = None,
limit: Optional[int] = 100,
offset: Optional[int] = 0,
order_by: Optional[List[str]] = None,
) -> List[Trace]: ...
async def get_span_tree(
self,
span_id: str,
attributes_to_return: Optional[List[str]] = None,
max_depth: Optional[int] = None,
) -> SpanWithChildren: ...
class SQLiteTraceStore(TraceStore):
def __init__(self, conn_string: str):
self.conn_string = conn_string
async def query_traces(
self,
attribute_filters: Optional[List[QueryCondition]] = None,
limit: Optional[int] = 100,
offset: Optional[int] = 0,
order_by: Optional[List[str]] = None,
) -> List[Trace]:
def build_where_clause() -> tuple[str, list]:
if not attribute_filters:
return "", []
ops_map = {"eq": "=", "ne": "!=", "gt": ">", "lt": "<"}
conditions = [
f"json_extract(s.attributes, '$.{condition.key}') {ops_map[condition.op.value]} ?"
for condition in attribute_filters
]
params = [condition.value for condition in attribute_filters]
where_clause = " WHERE " + " AND ".join(conditions)
return where_clause, params
def build_order_clause() -> str:
if not order_by:
return ""
order_clauses = []
for field in order_by:
desc = field.startswith("-")
clean_field = field[1:] if desc else field
order_clauses.append(f"t.{clean_field} {'DESC' if desc else 'ASC'}")
return " ORDER BY " + ", ".join(order_clauses)
# Build the main query
base_query = """
WITH matching_traces AS (
SELECT DISTINCT t.trace_id
FROM traces t
JOIN spans s ON t.trace_id = s.trace_id
{where_clause}
),
filtered_traces AS (
SELECT t.trace_id, t.root_span_id, t.start_time, t.end_time
FROM matching_traces mt
JOIN traces t ON mt.trace_id = t.trace_id
LEFT JOIN spans s ON t.trace_id = s.trace_id
{order_clause}
)
SELECT DISTINCT trace_id, root_span_id, start_time, end_time
FROM filtered_traces
LIMIT {limit} OFFSET {offset}
"""
where_clause, params = build_where_clause()
query = base_query.format(
where_clause=where_clause,
order_clause=build_order_clause(),
limit=limit,
offset=offset,
)
# Execute query and return results
async with aiosqlite.connect(self.conn_string) as conn:
conn.row_factory = aiosqlite.Row
async with conn.execute(query, params) as cursor:
rows = await cursor.fetchall()
return [
Trace(
trace_id=row["trace_id"],
root_span_id=row["root_span_id"],
start_time=datetime.fromisoformat(row["start_time"]),
end_time=datetime.fromisoformat(row["end_time"]),
)
for row in rows
]
async def get_span_tree(
self,
span_id: str,
attributes_to_return: Optional[List[str]] = None,
max_depth: Optional[int] = None,
) -> SpanWithChildren:
# Build the attributes selection
attributes_select = "s.attributes"
if attributes_to_return:
json_object = ", ".join(
f"'{key}', json_extract(s.attributes, '$.{key}')"
for key in attributes_to_return
)
attributes_select = f"json_object({json_object})"
# SQLite CTE query with filtered attributes
query = f"""
WITH RECURSIVE span_tree AS (
SELECT s.*, 1 as depth, {attributes_select} as filtered_attributes
FROM spans s
WHERE s.span_id = ?
UNION ALL
SELECT s.*, st.depth + 1, {attributes_select} as filtered_attributes
FROM spans s
JOIN span_tree st ON s.parent_span_id = st.span_id
WHERE (? IS NULL OR st.depth < ?)
)
SELECT *
FROM span_tree
ORDER BY depth, start_time
"""
async with aiosqlite.connect(self.conn_string) as conn:
conn.row_factory = aiosqlite.Row
async with conn.execute(query, (span_id, max_depth, max_depth)) as cursor:
rows = await cursor.fetchall()
if not rows:
raise ValueError(f"Span {span_id} not found")
# Build span tree
spans_by_id = {}
root_span = None
for row in rows:
span = SpanWithChildren(
span_id=row["span_id"],
trace_id=row["trace_id"],
parent_span_id=row["parent_span_id"],
name=row["name"],
start_time=datetime.fromisoformat(row["start_time"]),
end_time=datetime.fromisoformat(row["end_time"]),
attributes=json.loads(row["filtered_attributes"]),
status=row["status"].lower(),
children=[],
)
spans_by_id[span.span_id] = span
if span.span_id == span_id:
root_span = span
elif span.parent_span_id in spans_by_id:
spans_by_id[span.parent_span_id].children.append(span)
return root_span

View file

@ -0,0 +1,141 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import inspect
from datetime import datetime
from functools import wraps
from typing import Any, AsyncGenerator, Callable, Type, TypeVar
from uuid import UUID
from pydantic import BaseModel
T = TypeVar("T")
def serialize_value(value: Any) -> Any:
"""Serialize a single value into JSON-compatible format."""
if value is None:
return None
elif isinstance(value, (str, int, float, bool)):
return value
elif isinstance(value, BaseModel):
return value.model_dump()
elif isinstance(value, (list, tuple, set)):
return [serialize_value(item) for item in value]
elif isinstance(value, dict):
return {str(k): serialize_value(v) for k, v in value.items()}
elif isinstance(value, (datetime, UUID)):
return str(value)
else:
return str(value)
def trace_protocol(cls: Type[T]) -> Type[T]:
"""
A class decorator that automatically traces all methods in a protocol/base class
and its inheriting classes.
"""
def trace_method(method: Callable) -> Callable:
from llama_stack.providers.utils.telemetry import tracing
is_async = asyncio.iscoroutinefunction(method)
is_async_gen = inspect.isasyncgenfunction(method)
def create_span_context(self: Any, *args: Any, **kwargs: Any) -> tuple:
class_name = self.__class__.__name__
method_name = method.__name__
span_type = (
"async_generator" if is_async_gen else "async" if is_async else "sync"
)
sig = inspect.signature(method)
param_names = list(sig.parameters.keys())[1:] # Skip 'self'
combined_args = {}
for i, arg in enumerate(args):
param_name = (
param_names[i] if i < len(param_names) else f"position_{i+1}"
)
combined_args[param_name] = serialize_value(arg)
for k, v in kwargs.items():
combined_args[str(k)] = serialize_value(v)
span_attributes = {
"__autotraced__": True,
"__class__": class_name,
"__method__": method_name,
"__type__": span_type,
"__args__": str(combined_args),
}
return class_name, method_name, span_attributes
@wraps(method)
async def async_gen_wrapper(
self: Any, *args: Any, **kwargs: Any
) -> AsyncGenerator:
class_name, method_name, span_attributes = create_span_context(
self, *args, **kwargs
)
with tracing.span(f"{class_name}.{method_name}", span_attributes) as span:
try:
count = 0
async for item in method(self, *args, **kwargs):
yield item
count += 1
finally:
span.set_attribute("chunk_count", count)
@wraps(method)
async def async_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
class_name, method_name, span_attributes = create_span_context(
self, *args, **kwargs
)
with tracing.span(f"{class_name}.{method_name}", span_attributes) as span:
try:
result = await method(self, *args, **kwargs)
span.set_attribute("output", serialize_value(result))
return result
except Exception as e:
span.set_attribute("error", str(e))
raise
@wraps(method)
def sync_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
class_name, method_name, span_attributes = create_span_context(
self, *args, **kwargs
)
with tracing.span(f"{class_name}.{method_name}", span_attributes) as span:
try:
result = method(self, *args, **kwargs)
span.set_attribute("output", serialize_value(result))
return result
except Exception as _e:
raise
if is_async_gen:
return async_gen_wrapper
elif is_async:
return async_wrapper
else:
return sync_wrapper
original_init_subclass = getattr(cls, "__init_subclass__", None)
def __init_subclass__(cls_child, **kwargs): # noqa: N807
if original_init_subclass:
original_init_subclass(**kwargs)
for name, method in vars(cls_child).items():
if inspect.isfunction(method) and not name.startswith("_"):
setattr(cls_child, name, trace_method(method)) # noqa: B010
cls.__init_subclass__ = classmethod(__init_subclass__)
return cls

View file

@ -69,7 +69,7 @@ class TraceContext:
self.logger = logger
self.trace_id = trace_id
def push_span(self, name: str, attributes: Dict[str, Any] = None):
def push_span(self, name: str, attributes: Dict[str, Any] = None) -> Span:
current_span = self.get_current_span()
span = Span(
span_id=generate_short_uuid(),
@ -94,6 +94,7 @@ class TraceContext:
)
self.spans.append(span)
return span
def pop_span(self, status: SpanStatus = SpanStatus.OK):
span = self.spans.pop()
@ -203,12 +204,13 @@ class SpanContextManager:
def __init__(self, name: str, attributes: Dict[str, Any] = None):
self.name = name
self.attributes = attributes
self.span = None
def __enter__(self):
global CURRENT_TRACE_CONTEXT
context = CURRENT_TRACE_CONTEXT
if context:
context.push_span(self.name, self.attributes)
self.span = context.push_span(self.name, self.attributes)
return self
def __exit__(self, exc_type, exc_value, traceback):
@ -217,11 +219,24 @@ class SpanContextManager:
if context:
context.pop_span()
def set_attribute(self, key: str, value: Any):
if self.span:
if self.span.attributes is None:
self.span.attributes = {}
self.span.attributes[key] = value
async def __aenter__(self):
return self.__enter__()
global CURRENT_TRACE_CONTEXT
context = CURRENT_TRACE_CONTEXT
if context:
self.span = context.push_span(self.name, self.attributes)
return self
async def __aexit__(self, exc_type, exc_value, traceback):
self.__exit__(exc_type, exc_value, traceback)
global CURRENT_TRACE_CONTEXT
context = CURRENT_TRACE_CONTEXT
if context:
context.pop_span()
def __call__(self, func: Callable):
@wraps(func)
@ -246,3 +261,11 @@ class SpanContextManager:
def span(name: str, attributes: Dict[str, Any] = None):
return SpanContextManager(name, attributes)
def get_current_span() -> Optional[Span]:
global CURRENT_TRACE_CONTEXT
context = CURRENT_TRACE_CONTEXT
if context:
return context.get_current_span()
return None