mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-18 13:39:48 +00:00
Merge branch 'main' into post_training_v3
This commit is contained in:
commit
e2a0dce8ad
286 changed files with 13314 additions and 4467 deletions
|
|
@ -3,3 +3,5 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
#
|
||||
# from .distribution.library_client import LlamaStackAsLibraryClient, AsyncLlamaStackAsLibraryClient
|
||||
|
|
|
|||
|
|
@ -23,6 +23,7 @@ from llama_models.schema_utils import json_schema_type, webmethod
|
|||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_stack.apis.common.deployment_types import * # noqa: F403
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
|
|
@ -339,9 +340,8 @@ class AgentTurnResponseStepProgressPayload(BaseModel):
|
|||
step_type: StepType
|
||||
step_id: str
|
||||
|
||||
model_response_text_delta: Optional[str] = None
|
||||
text_delta: Optional[str] = None
|
||||
tool_call_delta: Optional[ToolCallDelta] = None
|
||||
tool_response_text_delta: Optional[str] = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
@ -418,6 +418,7 @@ class AgentStepResponse(BaseModel):
|
|||
|
||||
|
||||
@runtime_checkable
|
||||
@trace_protocol
|
||||
class Agents(Protocol):
|
||||
@webmethod(route="/agents/create")
|
||||
async def create_agent(
|
||||
|
|
|
|||
|
|
@ -121,7 +121,7 @@ class EventLogger:
|
|||
else:
|
||||
yield event, LogEvent(
|
||||
role=None,
|
||||
content=event.payload.model_response_text_delta,
|
||||
content=event.payload.text_delta,
|
||||
end="",
|
||||
color="yellow",
|
||||
)
|
||||
|
|
@ -171,12 +171,14 @@ class EventLogger:
|
|||
and event_type == EventType.step_complete.value
|
||||
):
|
||||
details = event.payload.step_details
|
||||
content = interleaved_text_media_as_str(details.inserted_context)
|
||||
content = content[:200] + "..." if len(content) > 200 else content
|
||||
inserted_context = interleaved_text_media_as_str(
|
||||
details.inserted_context
|
||||
)
|
||||
content = f"fetched {len(inserted_context)} bytes from {details.memory_bank_ids}"
|
||||
|
||||
yield event, LogEvent(
|
||||
role=step_type,
|
||||
content=f"Retrieved context from banks: {details.memory_bank_ids}.\n====\n{content}\n>",
|
||||
content=content,
|
||||
color="cyan",
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -37,3 +37,8 @@ class DatasetIO(Protocol):
|
|||
page_token: Optional[str] = None,
|
||||
filter_condition: Optional[str] = None,
|
||||
) -> PaginatedRowsResult: ...
|
||||
|
||||
@webmethod(route="/datasetio/append-rows", method="POST")
|
||||
async def append_rows(
|
||||
self, dataset_id: str, rows: List[Dict[str, Any]]
|
||||
) -> None: ...
|
||||
|
|
|
|||
|
|
@ -78,6 +78,21 @@ class DatasetsClient(Datasets):
|
|||
|
||||
return [DatasetDefWithProvider(**x) for x in response.json()]
|
||||
|
||||
async def unregister_dataset(
|
||||
self,
|
||||
dataset_id: str,
|
||||
) -> None:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.delete(
|
||||
f"{self.base_url}/datasets/unregister",
|
||||
params={
|
||||
"dataset_id": dataset_id,
|
||||
},
|
||||
headers={"Content-Type": "application/json"},
|
||||
timeout=60,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
|
||||
async def run_main(host: str, port: int):
|
||||
client = DatasetsClient(f"http://{host}:{port}")
|
||||
|
|
|
|||
|
|
@ -64,3 +64,9 @@ class Datasets(Protocol):
|
|||
|
||||
@webmethod(route="/datasets/list", method="GET")
|
||||
async def list_datasets(self) -> List[Dataset]: ...
|
||||
|
||||
@webmethod(route="/datasets/unregister", method="POST")
|
||||
async def unregister_dataset(
|
||||
self,
|
||||
dataset_id: str,
|
||||
) -> None: ...
|
||||
|
|
|
|||
|
|
@ -21,6 +21,8 @@ from llama_models.schema_utils import json_schema_type, webmethod
|
|||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
||||
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_stack.apis.models import * # noqa: F403
|
||||
|
||||
|
|
@ -220,6 +222,7 @@ class ModelStore(Protocol):
|
|||
|
||||
|
||||
@runtime_checkable
|
||||
@trace_protocol
|
||||
class Inference(Protocol):
|
||||
model_store: ModelStore
|
||||
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ from pydantic import BaseModel, Field
|
|||
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_stack.apis.memory_banks import * # noqa: F403
|
||||
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
@ -43,6 +44,7 @@ class MemoryBankStore(Protocol):
|
|||
|
||||
|
||||
@runtime_checkable
|
||||
@trace_protocol
|
||||
class Memory(Protocol):
|
||||
memory_bank_store: MemoryBankStore
|
||||
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@ from llama_models.schema_utils import json_schema_type, webmethod
|
|||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.apis.resource import Resource, ResourceType
|
||||
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
@ -88,6 +89,7 @@ class VectorMemoryBank(MemoryBankResourceMixin):
|
|||
memory_bank_type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value
|
||||
embedding_model: str
|
||||
chunk_size_in_tokens: int
|
||||
embedding_dimension: Optional[int] = 384 # default to minilm-l6-v2
|
||||
overlap_size_in_tokens: Optional[int] = None
|
||||
|
||||
|
||||
|
|
@ -129,6 +131,7 @@ class MemoryBankInput(BaseModel):
|
|||
|
||||
|
||||
@runtime_checkable
|
||||
@trace_protocol
|
||||
class MemoryBanks(Protocol):
|
||||
@webmethod(route="/memory-banks/list", method="GET")
|
||||
async def list_memory_banks(self) -> List[MemoryBank]: ...
|
||||
|
|
|
|||
|
|
@ -40,7 +40,7 @@ class ModelsClient(Models):
|
|||
response = await client.post(
|
||||
f"{self.base_url}/models/register",
|
||||
json={
|
||||
"model": json.loads(model.json()),
|
||||
"model": json.loads(model.model_dump_json()),
|
||||
},
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
|
|
|
|||
|
|
@ -4,12 +4,14 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable
|
||||
|
||||
from llama_models.schema_utils import json_schema_type, webmethod
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from llama_stack.apis.resource import Resource, ResourceType
|
||||
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
||||
|
||||
|
||||
class CommonModelFields(BaseModel):
|
||||
|
|
@ -19,6 +21,11 @@ class CommonModelFields(BaseModel):
|
|||
)
|
||||
|
||||
|
||||
class ModelType(Enum):
|
||||
llm = "llm"
|
||||
embedding_model = "embedding"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class Model(CommonModelFields, Resource):
|
||||
type: Literal[ResourceType.model.value] = ResourceType.model.value
|
||||
|
|
@ -33,16 +40,19 @@ class Model(CommonModelFields, Resource):
|
|||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
model_type: ModelType = Field(default=ModelType.llm)
|
||||
|
||||
|
||||
class ModelInput(CommonModelFields):
|
||||
model_id: str
|
||||
provider_id: Optional[str] = None
|
||||
provider_model_id: Optional[str] = None
|
||||
|
||||
model_type: Optional[ModelType] = ModelType.llm
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
@trace_protocol
|
||||
class Models(Protocol):
|
||||
@webmethod(route="/models/list", method="GET")
|
||||
async def list_models(self) -> List[Model]: ...
|
||||
|
|
@ -57,6 +67,7 @@ class Models(Protocol):
|
|||
provider_model_id: Optional[str] = None,
|
||||
provider_id: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
model_type: Optional[ModelType] = None,
|
||||
) -> Model: ...
|
||||
|
||||
@webmethod(route="/models/unregister", method="POST")
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ from typing import Any, Dict, List, Optional, Protocol, Union
|
|||
from llama_models.schema_utils import json_schema_type, webmethod
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_stack.apis.common.job_types import JobStatus
|
||||
|
|
@ -80,6 +81,11 @@ class QATFinetuningConfig(BaseModel):
|
|||
group_size: int
|
||||
|
||||
|
||||
AlgorithmConfig = Annotated[
|
||||
Union[LoraFinetuningConfig, LoraFinetuningConfig], Field(discriminator="type")
|
||||
]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class PostTrainingJobLogStream(BaseModel):
|
||||
"""Stream of logs from a finetuning job."""
|
||||
|
|
@ -166,9 +172,7 @@ class PostTraining(Protocol):
|
|||
description="Model descriptor from `llama model list`",
|
||||
),
|
||||
checkpoint_dir: Optional[str] = None,
|
||||
algorithm_config: Optional[
|
||||
Union[LoraFinetuningConfig, QATFinetuningConfig]
|
||||
] = None,
|
||||
algorithm_config: Optional[AlgorithmConfig] = None,
|
||||
) -> PostTrainingJob: ...
|
||||
|
||||
@webmethod(route="/post-training/preference-optimize", method="POST")
|
||||
|
|
|
|||
|
|
@ -17,6 +17,8 @@ from llama_models.llama3.api.datatypes import * # noqa: F403
|
|||
from pydantic import BaseModel
|
||||
from termcolor import cprint
|
||||
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_VERSION
|
||||
|
||||
from llama_stack.distribution.datatypes import RemoteProviderConfig
|
||||
|
||||
from llama_stack.apis.safety import * # noqa: F403
|
||||
|
|
@ -45,7 +47,7 @@ class SafetyClient(Safety):
|
|||
) -> RunShieldResponse:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
f"{self.base_url}/safety/run_shield",
|
||||
f"{self.base_url}/{LLAMA_STACK_API_VERSION}/safety/run-shield",
|
||||
json=dict(
|
||||
shield_id=shield_id,
|
||||
messages=[encodable_dict(m) for m in messages],
|
||||
|
|
@ -91,7 +93,7 @@ async def run_main(host: str, port: int, image_path: str = None):
|
|||
]:
|
||||
cprint(f"User>{message.content}", "green")
|
||||
response = await client.run_shield(
|
||||
shield_id="llama_guard",
|
||||
shield_id="meta-llama/Llama-Guard-3-1B",
|
||||
messages=[message],
|
||||
)
|
||||
print(response)
|
||||
|
|
|
|||
|
|
@ -10,6 +10,8 @@ from typing import Any, Dict, List, Protocol, runtime_checkable
|
|||
from llama_models.schema_utils import json_schema_type, webmethod
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
||||
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_stack.apis.shields import * # noqa: F403
|
||||
|
||||
|
|
@ -43,6 +45,7 @@ class ShieldStore(Protocol):
|
|||
|
||||
|
||||
@runtime_checkable
|
||||
@trace_protocol
|
||||
class Safety(Protocol):
|
||||
shield_store: ShieldStore
|
||||
|
||||
|
|
|
|||
|
|
@ -31,6 +31,15 @@ from llama_stack.apis.resource import Resource, ResourceType
|
|||
class ScoringFnParamsType(Enum):
|
||||
llm_as_judge = "llm_as_judge"
|
||||
regex_parser = "regex_parser"
|
||||
basic = "basic"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AggregationFunctionType(Enum):
|
||||
average = "average"
|
||||
median = "median"
|
||||
categorical_count = "categorical_count"
|
||||
accuracy = "accuracy"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
@ -44,6 +53,10 @@ class LLMAsJudgeScoringFnParams(BaseModel):
|
|||
description="Regexes to extract the answer from generated response",
|
||||
default_factory=list,
|
||||
)
|
||||
aggregation_functions: Optional[List[AggregationFunctionType]] = Field(
|
||||
description="Aggregation functions to apply to the scores of each row",
|
||||
default_factory=list,
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
@ -55,12 +68,26 @@ class RegexParserScoringFnParams(BaseModel):
|
|||
description="Regex to extract the answer from generated response",
|
||||
default_factory=list,
|
||||
)
|
||||
aggregation_functions: Optional[List[AggregationFunctionType]] = Field(
|
||||
description="Aggregation functions to apply to the scores of each row",
|
||||
default_factory=list,
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class BasicScoringFnParams(BaseModel):
|
||||
type: Literal[ScoringFnParamsType.basic.value] = ScoringFnParamsType.basic.value
|
||||
aggregation_functions: Optional[List[AggregationFunctionType]] = Field(
|
||||
description="Aggregation functions to apply to the scores of each row",
|
||||
default_factory=list,
|
||||
)
|
||||
|
||||
|
||||
ScoringFnParams = Annotated[
|
||||
Union[
|
||||
LLMAsJudgeScoringFnParams,
|
||||
RegexParserScoringFnParams,
|
||||
BasicScoringFnParams,
|
||||
],
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ from llama_models.schema_utils import json_schema_type, webmethod
|
|||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.resource import Resource, ResourceType
|
||||
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
||||
|
||||
|
||||
class CommonShieldFields(BaseModel):
|
||||
|
|
@ -38,6 +39,7 @@ class ShieldInput(CommonShieldFields):
|
|||
|
||||
|
||||
@runtime_checkable
|
||||
@trace_protocol
|
||||
class Shields(Protocol):
|
||||
@webmethod(route="/shields/list", method="GET")
|
||||
async def list_shields(self) -> List[Shield]: ...
|
||||
|
|
|
|||
|
|
@ -6,12 +6,24 @@
|
|||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, Literal, Optional, Protocol, runtime_checkable, Union
|
||||
from typing import (
|
||||
Any,
|
||||
Dict,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Protocol,
|
||||
runtime_checkable,
|
||||
Union,
|
||||
)
|
||||
|
||||
from llama_models.schema_utils import json_schema_type, webmethod
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import Annotated
|
||||
|
||||
# Add this constant near the top of the file, after the imports
|
||||
DEFAULT_TTL_DAYS = 7
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class SpanStatus(Enum):
|
||||
|
|
@ -29,6 +41,11 @@ class Span(BaseModel):
|
|||
end_time: Optional[datetime] = None
|
||||
attributes: Optional[Dict[str, Any]] = Field(default_factory=dict)
|
||||
|
||||
def set_attribute(self, key: str, value: Any):
|
||||
if self.attributes is None:
|
||||
self.attributes = {}
|
||||
self.attributes[key] = value
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class Trace(BaseModel):
|
||||
|
|
@ -123,10 +140,73 @@ Event = Annotated[
|
|||
]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class EvalTrace(BaseModel):
|
||||
session_id: str
|
||||
step: str
|
||||
input: str
|
||||
output: str
|
||||
expected_output: str
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class SpanWithChildren(Span):
|
||||
children: List["SpanWithChildren"] = Field(default_factory=list)
|
||||
status: Optional[SpanStatus] = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class QueryConditionOp(Enum):
|
||||
EQ = "eq"
|
||||
NE = "ne"
|
||||
GT = "gt"
|
||||
LT = "lt"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class QueryCondition(BaseModel):
|
||||
key: str
|
||||
op: QueryConditionOp
|
||||
value: Any
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class Telemetry(Protocol):
|
||||
@webmethod(route="/telemetry/log-event")
|
||||
async def log_event(self, event: Event) -> None: ...
|
||||
async def log_event(
|
||||
self, event: Event, ttl_seconds: int = DEFAULT_TTL_DAYS * 86400
|
||||
) -> None: ...
|
||||
|
||||
@webmethod(route="/telemetry/get-trace", method="GET")
|
||||
async def get_trace(self, trace_id: str) -> Trace: ...
|
||||
@webmethod(route="/telemetry/query-traces", method="POST")
|
||||
async def query_traces(
|
||||
self,
|
||||
attribute_filters: Optional[List[QueryCondition]] = None,
|
||||
limit: Optional[int] = 100,
|
||||
offset: Optional[int] = 0,
|
||||
order_by: Optional[List[str]] = None,
|
||||
) -> List[Trace]: ...
|
||||
|
||||
@webmethod(route="/telemetry/get-span-tree", method="POST")
|
||||
async def get_span_tree(
|
||||
self,
|
||||
span_id: str,
|
||||
attributes_to_return: Optional[List[str]] = None,
|
||||
max_depth: Optional[int] = None,
|
||||
) -> SpanWithChildren: ...
|
||||
|
||||
@webmethod(route="/telemetry/query-spans", method="POST")
|
||||
async def query_spans(
|
||||
self,
|
||||
attribute_filters: List[QueryCondition],
|
||||
attributes_to_return: List[str],
|
||||
max_depth: Optional[int] = None,
|
||||
) -> List[Span]: ...
|
||||
|
||||
@webmethod(route="/telemetry/save-spans-to-dataset", method="POST")
|
||||
async def save_spans_to_dataset(
|
||||
self,
|
||||
attribute_filters: List[QueryCondition],
|
||||
attributes_to_save: List[str],
|
||||
dataset_id: str,
|
||||
max_depth: Optional[int] = None,
|
||||
) -> None: ...
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ from llama_stack.distribution.distribution import get_provider_registry
|
|||
from llama_stack.distribution.resolver import InvalidProviderError
|
||||
from llama_stack.distribution.utils.dynamic import instantiate_class_type
|
||||
|
||||
TEMPLATES_PATH = Path(os.path.relpath(__file__)).parent.parent.parent / "templates"
|
||||
TEMPLATES_PATH = Path(__file__).parent.parent.parent / "templates"
|
||||
|
||||
|
||||
@lru_cache()
|
||||
|
|
@ -51,7 +51,7 @@ class StackBuild(Subcommand):
|
|||
"--config",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to a config file to use for the build. You can find example configs in llama_stack/distribution/example_configs. If this argument is not provided, you will be prompted to enter information interactively",
|
||||
help="Path to a config file to use for the build. You can find example configs in llama_stack/distribution/**/build.yaml. If this argument is not provided, you will be prompted to enter information interactively",
|
||||
)
|
||||
|
||||
self.parser.add_argument(
|
||||
|
|
@ -73,7 +73,7 @@ class StackBuild(Subcommand):
|
|||
"--image-type",
|
||||
type=str,
|
||||
help="Image Type to use for the build. This can be either conda or docker. If not specified, will use the image type from the template config.",
|
||||
choices=["conda", "docker"],
|
||||
choices=["conda", "docker", "venv"],
|
||||
default="conda",
|
||||
)
|
||||
|
||||
|
|
@ -124,8 +124,8 @@ class StackBuild(Subcommand):
|
|||
image_type = prompt(
|
||||
"> Enter the image type you want your Llama Stack to be built as (docker or conda): ",
|
||||
validator=Validator.from_callable(
|
||||
lambda x: x in ["docker", "conda"],
|
||||
error_message="Invalid image type, please enter conda or docker",
|
||||
lambda x: x in ["docker", "conda", "venv"],
|
||||
error_message="Invalid image type, please enter conda or docker or venv",
|
||||
),
|
||||
default="conda",
|
||||
)
|
||||
|
|
@ -261,7 +261,6 @@ class StackBuild(Subcommand):
|
|||
) -> None:
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
|
||||
import yaml
|
||||
from termcolor import cprint
|
||||
|
|
@ -291,20 +290,8 @@ class StackBuild(Subcommand):
|
|||
run_config_file = build_dir / f"{build_config.name}-run.yaml"
|
||||
shutil.copy(template_path, run_config_file)
|
||||
|
||||
with open(template_path, "r") as f:
|
||||
yaml_content = f.read()
|
||||
|
||||
# Find all ${env.VARIABLE} patterns
|
||||
env_vars = set(re.findall(r"\${env\.([A-Za-z0-9_]+)}", yaml_content))
|
||||
cprint("Build Successful! Next steps: ", color="green")
|
||||
cprint(
|
||||
f" 1. Set the environment variables: {list(env_vars)}",
|
||||
color="green",
|
||||
)
|
||||
cprint(
|
||||
f" 2. Run: `llama stack run {template_name}`",
|
||||
color="green",
|
||||
)
|
||||
cprint("Build Successful!", color="green")
|
||||
else:
|
||||
self._generate_run_config(build_config, build_dir)
|
||||
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ from typing import List
|
|||
|
||||
import pkg_resources
|
||||
from pydantic import BaseModel
|
||||
from termcolor import cprint
|
||||
|
||||
from llama_stack.distribution.utils.exec import run_with_pty
|
||||
|
||||
|
|
@ -37,6 +38,7 @@ SERVER_DEPENDENCIES = [
|
|||
class ImageType(Enum):
|
||||
docker = "docker"
|
||||
conda = "conda"
|
||||
venv = "venv"
|
||||
|
||||
|
||||
class ApiInput(BaseModel):
|
||||
|
|
@ -45,7 +47,7 @@ class ApiInput(BaseModel):
|
|||
|
||||
|
||||
def get_provider_dependencies(
|
||||
config_providers: Dict[str, List[Provider]]
|
||||
config_providers: Dict[str, List[Provider]],
|
||||
) -> tuple[list[str], list[str]]:
|
||||
"""Get normal and special dependencies from provider configuration."""
|
||||
all_providers = get_provider_registry()
|
||||
|
|
@ -90,11 +92,12 @@ def get_provider_dependencies(
|
|||
def print_pip_install_help(providers: Dict[str, List[Provider]]):
|
||||
normal_deps, special_deps = get_provider_dependencies(providers)
|
||||
|
||||
print(
|
||||
f"Please install needed dependencies using the following commands:\n\n\tpip install {' '.join(normal_deps)}"
|
||||
cprint(
|
||||
f"Please install needed dependencies using the following commands:\n\npip install {' '.join(normal_deps)}",
|
||||
"yellow",
|
||||
)
|
||||
for special_dep in special_deps:
|
||||
log.info(f"\tpip install {special_dep}")
|
||||
cprint(f"pip install {special_dep}", "yellow")
|
||||
print()
|
||||
|
||||
|
||||
|
|
@ -118,7 +121,7 @@ def build_image(build_config: BuildConfig, build_file_path: Path):
|
|||
str(BUILDS_BASE_DIR / ImageType.docker.value),
|
||||
" ".join(normal_deps),
|
||||
]
|
||||
else:
|
||||
elif build_config.image_type == ImageType.conda.value:
|
||||
script = pkg_resources.resource_filename(
|
||||
"llama_stack", "distribution/build_conda_env.sh"
|
||||
)
|
||||
|
|
@ -128,6 +131,16 @@ def build_image(build_config: BuildConfig, build_file_path: Path):
|
|||
str(build_file_path),
|
||||
" ".join(normal_deps),
|
||||
]
|
||||
elif build_config.image_type == ImageType.venv.value:
|
||||
script = pkg_resources.resource_filename(
|
||||
"llama_stack", "distribution/build_venv.sh"
|
||||
)
|
||||
args = [
|
||||
script,
|
||||
build_config.name,
|
||||
str(build_file_path),
|
||||
" ".join(normal_deps),
|
||||
]
|
||||
|
||||
if special_deps:
|
||||
args.append("#".join(special_deps))
|
||||
|
|
|
|||
105
llama_stack/distribution/build_venv.sh
Executable file
105
llama_stack/distribution/build_venv.sh
Executable file
|
|
@ -0,0 +1,105 @@
|
|||
#!/bin/bash
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# TODO: combine this with build_conda_env.sh since it is almost identical
|
||||
# the only difference is that we don't do any conda-specific setup
|
||||
|
||||
LLAMA_MODELS_DIR=${LLAMA_MODELS_DIR:-}
|
||||
LLAMA_STACK_DIR=${LLAMA_STACK_DIR:-}
|
||||
TEST_PYPI_VERSION=${TEST_PYPI_VERSION:-}
|
||||
|
||||
if [ -n "$LLAMA_STACK_DIR" ]; then
|
||||
echo "Using llama-stack-dir=$LLAMA_STACK_DIR"
|
||||
fi
|
||||
if [ -n "$LLAMA_MODELS_DIR" ]; then
|
||||
echo "Using llama-models-dir=$LLAMA_MODELS_DIR"
|
||||
fi
|
||||
|
||||
if [ "$#" -lt 3 ]; then
|
||||
echo "Usage: $0 <distribution_type> <build_name> <build_file_path> <pip_dependencies> [<special_pip_deps>]" >&2
|
||||
echo "Example: $0 <distribution_type> mybuild ./my-stack-build.yaml 'numpy pandas scipy'" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
special_pip_deps="$4"
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
build_name="$1"
|
||||
env_name="llamastack-$build_name"
|
||||
build_file_path="$2"
|
||||
pip_dependencies="$3"
|
||||
|
||||
# Define color codes
|
||||
RED='\033[0;31m'
|
||||
GREEN='\033[0;32m'
|
||||
NC='\033[0m' # No Color
|
||||
|
||||
# this is set if we actually create a new conda in which case we need to clean up
|
||||
ENVNAME=""
|
||||
|
||||
SCRIPT_DIR=$(dirname "$(readlink -f "$0")")
|
||||
source "$SCRIPT_DIR/common.sh"
|
||||
|
||||
run() {
|
||||
local env_name="$1"
|
||||
local pip_dependencies="$2"
|
||||
local special_pip_deps="$3"
|
||||
|
||||
if [ -n "$TEST_PYPI_VERSION" ]; then
|
||||
# these packages are damaged in test-pypi, so install them first
|
||||
pip install fastapi libcst
|
||||
pip install --extra-index-url https://test.pypi.org/simple/ \
|
||||
llama-models==$TEST_PYPI_VERSION llama-stack==$TEST_PYPI_VERSION \
|
||||
$pip_dependencies
|
||||
if [ -n "$special_pip_deps" ]; then
|
||||
IFS='#' read -ra parts <<<"$special_pip_deps"
|
||||
for part in "${parts[@]}"; do
|
||||
echo "$part"
|
||||
pip install $part
|
||||
done
|
||||
fi
|
||||
else
|
||||
# Re-installing llama-stack in the new conda environment
|
||||
if [ -n "$LLAMA_STACK_DIR" ]; then
|
||||
if [ ! -d "$LLAMA_STACK_DIR" ]; then
|
||||
printf "${RED}Warning: LLAMA_STACK_DIR is set but directory does not exist: $LLAMA_STACK_DIR${NC}\n" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
printf "Installing from LLAMA_STACK_DIR: $LLAMA_STACK_DIR\n"
|
||||
pip install --no-cache-dir -e "$LLAMA_STACK_DIR"
|
||||
else
|
||||
pip install --no-cache-dir llama-stack
|
||||
fi
|
||||
|
||||
if [ -n "$LLAMA_MODELS_DIR" ]; then
|
||||
if [ ! -d "$LLAMA_MODELS_DIR" ]; then
|
||||
printf "${RED}Warning: LLAMA_MODELS_DIR is set but directory does not exist: $LLAMA_MODELS_DIR${NC}\n" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
printf "Installing from LLAMA_MODELS_DIR: $LLAMA_MODELS_DIR\n"
|
||||
pip uninstall -y llama-models
|
||||
pip install --no-cache-dir -e "$LLAMA_MODELS_DIR"
|
||||
fi
|
||||
|
||||
# Install pip dependencies
|
||||
printf "Installing pip dependencies\n"
|
||||
pip install $pip_dependencies
|
||||
if [ -n "$special_pip_deps" ]; then
|
||||
IFS='#' read -ra parts <<<"$special_pip_deps"
|
||||
for part in "${parts[@]}"; do
|
||||
echo "$part"
|
||||
pip install $part
|
||||
done
|
||||
fi
|
||||
fi
|
||||
}
|
||||
|
||||
run "$env_name" "$pip_dependencies" "$special_pip_deps"
|
||||
|
|
@ -165,5 +165,5 @@ class BuildConfig(BaseModel):
|
|||
)
|
||||
image_type: str = Field(
|
||||
default="conda",
|
||||
description="Type of package to build (conda | container)",
|
||||
description="Type of package to build (conda | docker | venv)",
|
||||
)
|
||||
|
|
|
|||
331
llama_stack/distribution/library_client.py
Normal file
331
llama_stack/distribution/library_client.py
Normal file
|
|
@ -0,0 +1,331 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
import json
|
||||
import os
|
||||
import queue
|
||||
import threading
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, Generator, get_args, get_origin, Optional, Type, TypeVar, Union
|
||||
|
||||
import yaml
|
||||
from llama_stack_client import AsyncLlamaStackClient, LlamaStackClient, NOT_GIVEN
|
||||
from pydantic import BaseModel, TypeAdapter
|
||||
from rich.console import Console
|
||||
|
||||
from termcolor import cprint
|
||||
|
||||
from llama_stack.distribution.build import print_pip_install_help
|
||||
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
|
||||
from llama_stack.distribution.datatypes import Api
|
||||
from llama_stack.distribution.resolver import ProviderRegistry
|
||||
from llama_stack.distribution.server.endpoints import get_all_api_endpoints
|
||||
from llama_stack.distribution.stack import (
|
||||
construct_stack,
|
||||
get_stack_run_config_from_template,
|
||||
replace_env_vars,
|
||||
)
|
||||
|
||||
from llama_stack.providers.utils.telemetry.tracing import (
|
||||
end_trace,
|
||||
setup_logger,
|
||||
start_trace,
|
||||
)
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def in_notebook():
|
||||
try:
|
||||
from IPython import get_ipython
|
||||
|
||||
if "IPKernelApp" not in get_ipython().config: # pragma: no cover
|
||||
return False
|
||||
except ImportError:
|
||||
return False
|
||||
except AttributeError:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def stream_across_asyncio_run_boundary(
|
||||
async_gen_maker,
|
||||
pool_executor: ThreadPoolExecutor,
|
||||
) -> Generator[T, None, None]:
|
||||
result_queue = queue.Queue()
|
||||
stop_event = threading.Event()
|
||||
|
||||
async def consumer():
|
||||
# make sure we make the generator in the event loop context
|
||||
gen = await async_gen_maker()
|
||||
try:
|
||||
async for item in gen:
|
||||
result_queue.put(item)
|
||||
except Exception as e:
|
||||
print(f"Error in generator {e}")
|
||||
result_queue.put(e)
|
||||
except asyncio.CancelledError:
|
||||
return
|
||||
finally:
|
||||
result_queue.put(StopIteration)
|
||||
stop_event.set()
|
||||
|
||||
def run_async():
|
||||
# Run our own loop to avoid double async generator cleanup which is done
|
||||
# by asyncio.run()
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
task = loop.create_task(consumer())
|
||||
loop.run_until_complete(task)
|
||||
finally:
|
||||
# Handle pending tasks like a generator's athrow()
|
||||
pending = asyncio.all_tasks(loop)
|
||||
if pending:
|
||||
loop.run_until_complete(
|
||||
asyncio.gather(*pending, return_exceptions=True)
|
||||
)
|
||||
loop.close()
|
||||
|
||||
future = pool_executor.submit(run_async)
|
||||
|
||||
try:
|
||||
# yield results as they come in
|
||||
while not stop_event.is_set() or not result_queue.empty():
|
||||
try:
|
||||
item = result_queue.get(timeout=0.1)
|
||||
if item is StopIteration:
|
||||
break
|
||||
if isinstance(item, Exception):
|
||||
raise item
|
||||
yield item
|
||||
except queue.Empty:
|
||||
continue
|
||||
finally:
|
||||
future.result()
|
||||
|
||||
|
||||
def convert_pydantic_to_json_value(value: Any, cast_to: Type) -> dict:
|
||||
if isinstance(value, Enum):
|
||||
return value.value
|
||||
elif isinstance(value, list):
|
||||
return [convert_pydantic_to_json_value(item, cast_to) for item in value]
|
||||
elif isinstance(value, dict):
|
||||
return {k: convert_pydantic_to_json_value(v, cast_to) for k, v in value.items()}
|
||||
elif isinstance(value, BaseModel):
|
||||
# This is quite hacky and we should figure out how to use stuff from
|
||||
# generated client-sdk code (using ApiResponse.parse() essentially)
|
||||
value_dict = json.loads(value.model_dump_json())
|
||||
|
||||
origin = get_origin(cast_to)
|
||||
if origin is Union:
|
||||
args = get_args(cast_to)
|
||||
for arg in args:
|
||||
arg_name = arg.__name__.split(".")[-1]
|
||||
value_name = value.__class__.__name__.split(".")[-1]
|
||||
if arg_name == value_name:
|
||||
return arg(**value_dict)
|
||||
|
||||
# assume we have the correct association between the server-side type and the client-side type
|
||||
return cast_to(**value_dict)
|
||||
|
||||
return value
|
||||
|
||||
|
||||
def convert_to_pydantic(annotation: Any, value: Any) -> Any:
|
||||
if isinstance(annotation, type) and annotation in {str, int, float, bool}:
|
||||
return value
|
||||
|
||||
origin = get_origin(annotation)
|
||||
if origin is list:
|
||||
item_type = get_args(annotation)[0]
|
||||
try:
|
||||
return [convert_to_pydantic(item_type, item) for item in value]
|
||||
except Exception:
|
||||
print(f"Error converting list {value}")
|
||||
return value
|
||||
|
||||
elif origin is dict:
|
||||
key_type, val_type = get_args(annotation)
|
||||
try:
|
||||
return {k: convert_to_pydantic(val_type, v) for k, v in value.items()}
|
||||
except Exception:
|
||||
print(f"Error converting dict {value}")
|
||||
return value
|
||||
|
||||
try:
|
||||
# Handle Pydantic models and discriminated unions
|
||||
return TypeAdapter(annotation).validate_python(value)
|
||||
except Exception as e:
|
||||
cprint(
|
||||
f"Warning: direct client failed to convert parameter {value} into {annotation}: {e}",
|
||||
"yellow",
|
||||
)
|
||||
return value
|
||||
|
||||
|
||||
class LlamaStackAsLibraryClient(LlamaStackClient):
|
||||
def __init__(
|
||||
self,
|
||||
config_path_or_template_name: str,
|
||||
custom_provider_registry: Optional[ProviderRegistry] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.async_client = AsyncLlamaStackAsLibraryClient(
|
||||
config_path_or_template_name, custom_provider_registry
|
||||
)
|
||||
self.pool_executor = ThreadPoolExecutor(max_workers=4)
|
||||
|
||||
def initialize(self):
|
||||
if in_notebook():
|
||||
import nest_asyncio
|
||||
|
||||
nest_asyncio.apply()
|
||||
|
||||
return asyncio.run(self.async_client.initialize())
|
||||
|
||||
def request(self, *args, **kwargs):
|
||||
if kwargs.get("stream"):
|
||||
return stream_across_asyncio_run_boundary(
|
||||
lambda: self.async_client.request(*args, **kwargs),
|
||||
self.pool_executor,
|
||||
)
|
||||
else:
|
||||
return asyncio.run(self.async_client.request(*args, **kwargs))
|
||||
|
||||
|
||||
class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
|
||||
def __init__(
|
||||
self,
|
||||
config_path_or_template_name: str,
|
||||
custom_provider_registry: Optional[ProviderRegistry] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# when using the library client, we should not log to console since many
|
||||
# of our logs are intended for server-side usage
|
||||
os.environ["TELEMETRY_SINKS"] = "sqlite"
|
||||
|
||||
if config_path_or_template_name.endswith(".yaml"):
|
||||
config_path = Path(config_path_or_template_name)
|
||||
if not config_path.exists():
|
||||
raise ValueError(f"Config file {config_path} does not exist")
|
||||
config_dict = replace_env_vars(yaml.safe_load(config_path.read_text()))
|
||||
config = parse_and_maybe_upgrade_config(config_dict)
|
||||
else:
|
||||
# template
|
||||
config = get_stack_run_config_from_template(config_path_or_template_name)
|
||||
|
||||
self.config_path_or_template_name = config_path_or_template_name
|
||||
self.config = config
|
||||
self.custom_provider_registry = custom_provider_registry
|
||||
|
||||
async def initialize(self):
|
||||
try:
|
||||
self.impls = await construct_stack(
|
||||
self.config, self.custom_provider_registry
|
||||
)
|
||||
except ModuleNotFoundError as _e:
|
||||
cprint(
|
||||
"Using llama-stack as a library requires installing dependencies depending on the template (providers) you choose.\n",
|
||||
"yellow",
|
||||
)
|
||||
if self.config_path_or_template_name.endswith(".yaml"):
|
||||
print_pip_install_help(self.config.providers)
|
||||
else:
|
||||
prefix = "!" if in_notebook() else ""
|
||||
cprint(
|
||||
f"Please run:\n\n{prefix}llama stack build --template {self.config_path_or_template_name} --image-type venv\n\n",
|
||||
"yellow",
|
||||
)
|
||||
return False
|
||||
|
||||
if Api.telemetry in self.impls:
|
||||
setup_logger(self.impls[Api.telemetry])
|
||||
|
||||
console = Console()
|
||||
console.print(f"Using config [blue]{self.config_path_or_template_name}[/blue]:")
|
||||
console.print(yaml.dump(self.config.model_dump(), indent=2))
|
||||
|
||||
endpoints = get_all_api_endpoints()
|
||||
endpoint_impls = {}
|
||||
for api, api_endpoints in endpoints.items():
|
||||
for endpoint in api_endpoints:
|
||||
impl = self.impls[api]
|
||||
func = getattr(impl, endpoint.name)
|
||||
endpoint_impls[endpoint.route] = func
|
||||
|
||||
self.endpoint_impls = endpoint_impls
|
||||
return True
|
||||
|
||||
async def request(
|
||||
self,
|
||||
cast_to: Any,
|
||||
options: Any,
|
||||
*,
|
||||
stream=False,
|
||||
stream_cls=None,
|
||||
):
|
||||
if not self.endpoint_impls:
|
||||
raise ValueError("Client not initialized")
|
||||
|
||||
params = options.params or {}
|
||||
params |= options.json_data or {}
|
||||
if stream:
|
||||
return self._call_streaming(options.url, params, cast_to)
|
||||
else:
|
||||
return await self._call_non_streaming(options.url, params, cast_to)
|
||||
|
||||
async def _call_non_streaming(
|
||||
self, path: str, body: dict = None, cast_to: Any = None
|
||||
):
|
||||
await start_trace(path, {"__location__": "library_client"})
|
||||
try:
|
||||
func = self.endpoint_impls.get(path)
|
||||
if not func:
|
||||
raise ValueError(f"No endpoint found for {path}")
|
||||
|
||||
body = self._convert_body(path, body)
|
||||
return convert_pydantic_to_json_value(await func(**body), cast_to)
|
||||
finally:
|
||||
await end_trace()
|
||||
|
||||
async def _call_streaming(self, path: str, body: dict = None, cast_to: Any = None):
|
||||
await start_trace(path, {"__location__": "library_client"})
|
||||
try:
|
||||
func = self.endpoint_impls.get(path)
|
||||
if not func:
|
||||
raise ValueError(f"No endpoint found for {path}")
|
||||
|
||||
body = self._convert_body(path, body)
|
||||
async for chunk in await func(**body):
|
||||
yield convert_pydantic_to_json_value(chunk, cast_to)
|
||||
finally:
|
||||
await end_trace()
|
||||
|
||||
def _convert_body(self, path: str, body: Optional[dict] = None) -> dict:
|
||||
if not body:
|
||||
return {}
|
||||
|
||||
func = self.endpoint_impls[path]
|
||||
sig = inspect.signature(func)
|
||||
|
||||
# Strip NOT_GIVENs to use the defaults in signature
|
||||
body = {k: v for k, v in body.items() if v is not NOT_GIVEN}
|
||||
|
||||
# Convert parameters to Pydantic models where needed
|
||||
converted_body = {}
|
||||
for param_name, param in sig.parameters.items():
|
||||
if param_name in body:
|
||||
value = body.get(param_name)
|
||||
converted_body[param_name] = convert_to_pydantic(
|
||||
param.annotation, value
|
||||
)
|
||||
return converted_body
|
||||
|
|
@ -35,7 +35,7 @@ class NeedsRequestProviderData:
|
|||
provider_data = validator(**val)
|
||||
return provider_data
|
||||
except Exception as e:
|
||||
log.error("Error parsing provider data", e)
|
||||
log.error(f"Error parsing provider data: {e}")
|
||||
|
||||
|
||||
def set_request_provider_data(headers: Dict[str, str]):
|
||||
|
|
|
|||
|
|
@ -88,9 +88,10 @@ class InferenceRouter(Inference):
|
|||
provider_model_id: Optional[str] = None,
|
||||
provider_id: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
model_type: Optional[ModelType] = None,
|
||||
) -> None:
|
||||
await self.routing_table.register_model(
|
||||
model_id, provider_model_id, provider_id, metadata
|
||||
model_id, provider_model_id, provider_id, metadata, model_type
|
||||
)
|
||||
|
||||
async def chat_completion(
|
||||
|
|
@ -105,6 +106,13 @@ class InferenceRouter(Inference):
|
|||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
model = await self.routing_table.get_model(model_id)
|
||||
if model is None:
|
||||
raise ValueError(f"Model '{model_id}' not found")
|
||||
if model.model_type == ModelType.embedding_model:
|
||||
raise ValueError(
|
||||
f"Model '{model_id}' is an embedding model and does not support chat completions"
|
||||
)
|
||||
params = dict(
|
||||
model_id=model_id,
|
||||
messages=messages,
|
||||
|
|
@ -131,6 +139,13 @@ class InferenceRouter(Inference):
|
|||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
model = await self.routing_table.get_model(model_id)
|
||||
if model is None:
|
||||
raise ValueError(f"Model '{model_id}' not found")
|
||||
if model.model_type == ModelType.embedding_model:
|
||||
raise ValueError(
|
||||
f"Model '{model_id}' is an embedding model and does not support chat completions"
|
||||
)
|
||||
provider = self.routing_table.get_provider_impl(model_id)
|
||||
params = dict(
|
||||
model_id=model_id,
|
||||
|
|
@ -150,6 +165,13 @@ class InferenceRouter(Inference):
|
|||
model_id: str,
|
||||
contents: List[InterleavedTextMedia],
|
||||
) -> EmbeddingsResponse:
|
||||
model = await self.routing_table.get_model(model_id)
|
||||
if model is None:
|
||||
raise ValueError(f"Model '{model_id}' not found")
|
||||
if model.model_type == ModelType.llm:
|
||||
raise ValueError(
|
||||
f"Model '{model_id}' is an LLM model and does not support embeddings"
|
||||
)
|
||||
return await self.routing_table.get_provider_impl(model_id).embeddings(
|
||||
model_id=model_id,
|
||||
contents=contents,
|
||||
|
|
@ -222,6 +244,12 @@ class DatasetIORouter(DatasetIO):
|
|||
filter_condition=filter_condition,
|
||||
)
|
||||
|
||||
async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
|
||||
return await self.routing_table.get_provider_impl(dataset_id).append_rows(
|
||||
dataset_id=dataset_id,
|
||||
rows=rows,
|
||||
)
|
||||
|
||||
|
||||
class ScoringRouter(Scoring):
|
||||
def __init__(
|
||||
|
|
|
|||
|
|
@ -57,6 +57,8 @@ async def unregister_object_from_provider(obj: RoutableObject, p: Any) -> None:
|
|||
return await p.unregister_memory_bank(obj.identifier)
|
||||
elif api == Api.inference:
|
||||
return await p.unregister_model(obj.identifier)
|
||||
elif api == Api.datasetio:
|
||||
return await p.unregister_dataset(obj.identifier)
|
||||
else:
|
||||
raise ValueError(f"Unregister not supported for {api}")
|
||||
|
||||
|
|
@ -207,6 +209,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
|||
provider_model_id: Optional[str] = None,
|
||||
provider_id: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
model_type: Optional[ModelType] = None,
|
||||
) -> Model:
|
||||
if provider_model_id is None:
|
||||
provider_model_id = model_id
|
||||
|
|
@ -220,11 +223,21 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
|||
)
|
||||
if metadata is None:
|
||||
metadata = {}
|
||||
if model_type is None:
|
||||
model_type = ModelType.llm
|
||||
if (
|
||||
"embedding_dimension" not in metadata
|
||||
and model_type == ModelType.embedding_model
|
||||
):
|
||||
raise ValueError(
|
||||
"Embedding model must have an embedding dimension in its metadata"
|
||||
)
|
||||
model = Model(
|
||||
identifier=model_id,
|
||||
provider_resource_id=provider_model_id,
|
||||
provider_id=provider_id,
|
||||
metadata=metadata,
|
||||
model_type=model_type,
|
||||
)
|
||||
registered_model = await self.register_object(model)
|
||||
return registered_model
|
||||
|
|
@ -296,16 +309,29 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
|
|||
raise ValueError(
|
||||
"No provider specified and multiple providers available. Please specify a provider_id."
|
||||
)
|
||||
memory_bank = parse_obj_as(
|
||||
MemoryBank,
|
||||
{
|
||||
"identifier": memory_bank_id,
|
||||
"type": ResourceType.memory_bank.value,
|
||||
"provider_id": provider_id,
|
||||
"provider_resource_id": provider_memory_bank_id,
|
||||
**params.model_dump(),
|
||||
},
|
||||
)
|
||||
model = await self.get_object_by_identifier("model", params.embedding_model)
|
||||
if model is None:
|
||||
raise ValueError(f"Model {params.embedding_model} not found")
|
||||
if model.model_type != ModelType.embedding_model:
|
||||
raise ValueError(
|
||||
f"Model {params.embedding_model} is not an embedding model"
|
||||
)
|
||||
if "embedding_dimension" not in model.metadata:
|
||||
raise ValueError(
|
||||
f"Model {params.embedding_model} does not have an embedding dimension"
|
||||
)
|
||||
memory_bank_data = {
|
||||
"identifier": memory_bank_id,
|
||||
"type": ResourceType.memory_bank.value,
|
||||
"provider_id": provider_id,
|
||||
"provider_resource_id": provider_memory_bank_id,
|
||||
**params.model_dump(),
|
||||
}
|
||||
if params.memory_bank_type == MemoryBankType.vector.value:
|
||||
memory_bank_data["embedding_dimension"] = model.metadata[
|
||||
"embedding_dimension"
|
||||
]
|
||||
memory_bank = parse_obj_as(MemoryBank, memory_bank_data)
|
||||
await self.register_object(memory_bank)
|
||||
return memory_bank
|
||||
|
||||
|
|
@ -354,6 +380,12 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
|
|||
)
|
||||
await self.register_object(dataset)
|
||||
|
||||
async def unregister_dataset(self, dataset_id: str) -> None:
|
||||
dataset = await self.get_dataset(dataset_id)
|
||||
if dataset is None:
|
||||
raise ValueError(f"Dataset {dataset_id} not found")
|
||||
await self.unregister_object(dataset)
|
||||
|
||||
|
||||
class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
|
||||
async def list_scoring_functions(self) -> List[ScoringFn]:
|
||||
|
|
|
|||
|
|
@ -17,13 +17,11 @@ import warnings
|
|||
|
||||
from contextlib import asynccontextmanager
|
||||
from pathlib import Path
|
||||
from ssl import SSLError
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Union
|
||||
|
||||
import httpx
|
||||
import yaml
|
||||
|
||||
from fastapi import Body, FastAPI, HTTPException, Request, Response
|
||||
from fastapi import Body, FastAPI, HTTPException, Request
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
|
@ -35,7 +33,6 @@ from llama_stack.distribution.distribution import builtin_automatically_routed_a
|
|||
from llama_stack.providers.utils.telemetry.tracing import (
|
||||
end_trace,
|
||||
setup_logger,
|
||||
SpanStatus,
|
||||
start_trace,
|
||||
)
|
||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||
|
|
@ -46,9 +43,9 @@ from llama_stack.distribution.stack import (
|
|||
replace_env_vars,
|
||||
validate_env_pair,
|
||||
)
|
||||
from llama_stack.providers.inline.meta_reference.telemetry.console import (
|
||||
ConsoleConfig,
|
||||
ConsoleTelemetryImpl,
|
||||
from llama_stack.providers.inline.telemetry.meta_reference.config import TelemetryConfig
|
||||
from llama_stack.providers.inline.telemetry.meta_reference.telemetry import (
|
||||
TelemetryAdapter,
|
||||
)
|
||||
|
||||
from .endpoints import get_all_api_endpoints
|
||||
|
|
@ -118,67 +115,6 @@ def translate_exception(exc: Exception) -> Union[HTTPException, RequestValidatio
|
|||
)
|
||||
|
||||
|
||||
async def passthrough(
|
||||
request: Request,
|
||||
downstream_url: str,
|
||||
downstream_headers: Optional[Dict[str, str]] = None,
|
||||
):
|
||||
await start_trace(request.path, {"downstream_url": downstream_url})
|
||||
|
||||
headers = dict(request.headers)
|
||||
headers.pop("host", None)
|
||||
headers.update(downstream_headers or {})
|
||||
|
||||
content = await request.body()
|
||||
|
||||
client = httpx.AsyncClient()
|
||||
erred = False
|
||||
try:
|
||||
req = client.build_request(
|
||||
method=request.method,
|
||||
url=downstream_url,
|
||||
headers=headers,
|
||||
content=content,
|
||||
params=request.query_params,
|
||||
)
|
||||
response = await client.send(req, stream=True)
|
||||
|
||||
async def stream_response():
|
||||
async for chunk in response.aiter_raw(chunk_size=64):
|
||||
yield chunk
|
||||
|
||||
await response.aclose()
|
||||
await client.aclose()
|
||||
|
||||
return StreamingResponse(
|
||||
stream_response(),
|
||||
status_code=response.status_code,
|
||||
headers=dict(response.headers),
|
||||
media_type=response.headers.get("content-type"),
|
||||
)
|
||||
|
||||
except httpx.ReadTimeout:
|
||||
erred = True
|
||||
return Response(content="Downstream server timed out", status_code=504)
|
||||
except httpx.NetworkError as e:
|
||||
erred = True
|
||||
return Response(content=f"Network error: {str(e)}", status_code=502)
|
||||
except httpx.TooManyRedirects:
|
||||
erred = True
|
||||
return Response(content="Too many redirects", status_code=502)
|
||||
except SSLError as e:
|
||||
erred = True
|
||||
return Response(content=f"SSL error: {str(e)}", status_code=502)
|
||||
except httpx.HTTPStatusError as e:
|
||||
erred = True
|
||||
return Response(content=str(e), status_code=e.response.status_code)
|
||||
except Exception as e:
|
||||
erred = True
|
||||
return Response(content=f"Unexpected error: {str(e)}", status_code=500)
|
||||
finally:
|
||||
await end_trace(SpanStatus.OK if not erred else SpanStatus.ERROR)
|
||||
|
||||
|
||||
def handle_sigint(app, *args, **kwargs):
|
||||
print("SIGINT or CTRL-C detected. Exiting gracefully...")
|
||||
|
||||
|
|
@ -217,7 +153,6 @@ async def maybe_await(value):
|
|||
|
||||
|
||||
async def sse_generator(event_gen):
|
||||
await start_trace("sse_generator")
|
||||
try:
|
||||
event_gen = await event_gen
|
||||
async for item in event_gen:
|
||||
|
|
@ -235,14 +170,10 @@ async def sse_generator(event_gen):
|
|||
},
|
||||
}
|
||||
)
|
||||
finally:
|
||||
await end_trace()
|
||||
|
||||
|
||||
def create_dynamic_typed_route(func: Any, method: str):
|
||||
async def endpoint(request: Request, **kwargs):
|
||||
await start_trace(func.__name__)
|
||||
|
||||
set_request_provider_data(request.headers)
|
||||
|
||||
is_streaming = is_streaming_request(func.__name__, request, **kwargs)
|
||||
|
|
@ -257,8 +188,6 @@ def create_dynamic_typed_route(func: Any, method: str):
|
|||
except Exception as e:
|
||||
traceback.print_exception(e)
|
||||
raise translate_exception(e) from e
|
||||
finally:
|
||||
await end_trace()
|
||||
|
||||
sig = inspect.signature(func)
|
||||
new_params = [
|
||||
|
|
@ -282,6 +211,19 @@ def create_dynamic_typed_route(func: Any, method: str):
|
|||
return endpoint
|
||||
|
||||
|
||||
class TracingMiddleware:
|
||||
def __init__(self, app):
|
||||
self.app = app
|
||||
|
||||
async def __call__(self, scope, receive, send):
|
||||
path = scope["path"]
|
||||
await start_trace(path, {"__location__": "server"})
|
||||
try:
|
||||
return await self.app(scope, receive, send)
|
||||
finally:
|
||||
await end_trace()
|
||||
|
||||
|
||||
def main():
|
||||
"""Start the LlamaStack server."""
|
||||
parser = argparse.ArgumentParser(description="Start the LlamaStack server.")
|
||||
|
|
@ -338,6 +280,7 @@ def main():
|
|||
print(yaml.dump(config.model_dump(), indent=2))
|
||||
|
||||
app = FastAPI(lifespan=lifespan)
|
||||
app.add_middleware(TracingMiddleware)
|
||||
|
||||
try:
|
||||
impls = asyncio.run(construct_stack(config))
|
||||
|
|
@ -347,7 +290,7 @@ def main():
|
|||
if Api.telemetry in impls:
|
||||
setup_logger(impls[Api.telemetry])
|
||||
else:
|
||||
setup_logger(ConsoleTelemetryImpl(ConsoleConfig()))
|
||||
setup_logger(TelemetryAdapter(TelemetryConfig()))
|
||||
|
||||
all_endpoints = get_all_api_endpoints()
|
||||
|
||||
|
|
|
|||
|
|
@ -40,7 +40,7 @@ class DistributionRegistry(Protocol):
|
|||
|
||||
|
||||
REGISTER_PREFIX = "distributions:registry"
|
||||
KEY_VERSION = "v2"
|
||||
KEY_VERSION = "v3"
|
||||
KEY_FORMAT = f"{REGISTER_PREFIX}:{KEY_VERSION}::" + "{type}:{identifier}"
|
||||
|
||||
|
||||
|
|
|
|||
128
llama_stack/distribution/tests/library_client_test.py
Normal file
128
llama_stack/distribution/tests/library_client_test.py
Normal file
|
|
@ -0,0 +1,128 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import argparse
|
||||
import os
|
||||
|
||||
from llama_stack.distribution.library_client import LlamaStackAsLibraryClient
|
||||
from llama_stack_client.lib.agents.agent import Agent
|
||||
from llama_stack_client.lib.agents.event_logger import EventLogger as AgentEventLogger
|
||||
from llama_stack_client.lib.inference.event_logger import EventLogger
|
||||
from llama_stack_client.types import Attachment, UserMessage
|
||||
from llama_stack_client.types.agent_create_params import AgentConfig
|
||||
|
||||
|
||||
def main(config_path: str):
|
||||
client = LlamaStackAsLibraryClient(config_path)
|
||||
if not client.initialize():
|
||||
return
|
||||
|
||||
models = client.models.list()
|
||||
print("\nModels:")
|
||||
for model in models:
|
||||
print(model)
|
||||
|
||||
if not models:
|
||||
print("No models found, skipping chat completion test")
|
||||
return
|
||||
|
||||
model_id = models[0].identifier
|
||||
response = client.inference.chat_completion(
|
||||
messages=[UserMessage(content="What is the capital of France?", role="user")],
|
||||
model_id=model_id,
|
||||
stream=False,
|
||||
)
|
||||
print("\nChat completion response (non-stream):")
|
||||
print(response)
|
||||
|
||||
response = client.inference.chat_completion(
|
||||
messages=[UserMessage(content="What is the capital of France?", role="user")],
|
||||
model_id=model_id,
|
||||
stream=True,
|
||||
)
|
||||
|
||||
print("\nChat completion response (stream):")
|
||||
for log in EventLogger().log(response):
|
||||
log.print()
|
||||
|
||||
print("\nAgent test:")
|
||||
agent_config = AgentConfig(
|
||||
model=model_id,
|
||||
instructions="You are a helpful assistant",
|
||||
sampling_params={
|
||||
"strategy": "greedy",
|
||||
"temperature": 1.0,
|
||||
"top_p": 0.9,
|
||||
},
|
||||
tools=(
|
||||
[
|
||||
{
|
||||
"type": "brave_search",
|
||||
"engine": "brave",
|
||||
"api_key": os.getenv("BRAVE_SEARCH_API_KEY"),
|
||||
}
|
||||
]
|
||||
if os.getenv("BRAVE_SEARCH_API_KEY")
|
||||
else []
|
||||
)
|
||||
+ (
|
||||
[
|
||||
{
|
||||
"type": "code_interpreter",
|
||||
}
|
||||
]
|
||||
),
|
||||
tool_choice="required",
|
||||
input_shields=[],
|
||||
output_shields=[],
|
||||
enable_session_persistence=False,
|
||||
)
|
||||
agent = Agent(client, agent_config)
|
||||
user_prompts = [
|
||||
"Hello",
|
||||
"Which players played in the winning team of the NBA western conference semifinals of 2024, please use tools",
|
||||
]
|
||||
user_prompts = [
|
||||
(
|
||||
"Here is a csv, can you describe it ?",
|
||||
[
|
||||
Attachment(
|
||||
content="https://raw.githubusercontent.com/meta-llama/llama-stack-apps/main/examples/resources/inflation.csv",
|
||||
mime_type="test/csv",
|
||||
)
|
||||
],
|
||||
),
|
||||
("Which year ended with the highest inflation ?", None),
|
||||
(
|
||||
"What macro economic situations that led to such high inflation in that period?",
|
||||
None,
|
||||
),
|
||||
("Plot average yearly inflation as a time series", None),
|
||||
]
|
||||
|
||||
session_id = agent.create_session("test-session")
|
||||
|
||||
for prompt, attachments in user_prompts:
|
||||
response = agent.create_turn(
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": prompt,
|
||||
}
|
||||
],
|
||||
attachments=attachments,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
for log in AgentEventLogger().log(response):
|
||||
log.print()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("config_path", help="Path to the config YAML file")
|
||||
args = parser.parse_args()
|
||||
main(args.config_path)
|
||||
42
llama_stack/distribution/ui/README.md
Normal file
42
llama_stack/distribution/ui/README.md
Normal file
|
|
@ -0,0 +1,42 @@
|
|||
# (Experimental) LLama Stack UI
|
||||
|
||||
## Docker Setup
|
||||
|
||||
:warning: This is a work in progress.
|
||||
|
||||
## Developer Setup
|
||||
|
||||
1. Start up Llama Stack API server. More details [here](https://llama-stack.readthedocs.io/en/latest/getting_started/index.html).
|
||||
|
||||
```
|
||||
llama stack build --template together --image-type conda
|
||||
|
||||
llama stack run together
|
||||
```
|
||||
|
||||
2. (Optional) Register datasets and eval tasks as resources. If you want to run pre-configured evaluation flows (e.g. Evaluations (Generation + Scoring) Page).
|
||||
|
||||
```bash
|
||||
$ llama-stack-client datasets register \
|
||||
--dataset-id "mmlu" \
|
||||
--provider-id "huggingface" \
|
||||
--url "https://huggingface.co/datasets/llamastack/evals" \
|
||||
--metadata '{"path": "llamastack/evals", "name": "evals__mmlu__details", "split": "train"}' \
|
||||
--schema '{"input_query": {"type": "string"}, "expected_answer": {"type": "string", "chat_completion_input": {"type": "string"}}}'
|
||||
```
|
||||
|
||||
```bash
|
||||
$ llama-stack-client eval_tasks register \
|
||||
--eval-task-id meta-reference-mmlu \
|
||||
--provider-id meta-reference \
|
||||
--dataset-id mmlu \
|
||||
--scoring-functions basic::regex_parser_multiple_choice_answer
|
||||
```
|
||||
|
||||
3. Start Streamlit UI
|
||||
|
||||
```bash
|
||||
cd llama_stack/distribution/ui
|
||||
pip install -r requirements.txt
|
||||
streamlit run app.py
|
||||
```
|
||||
57
llama_stack/distribution/ui/app.py
Normal file
57
llama_stack/distribution/ui/app.py
Normal file
|
|
@ -0,0 +1,57 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
import streamlit as st
|
||||
|
||||
|
||||
def main():
|
||||
# Evaluation pages
|
||||
application_evaluation_page = st.Page(
|
||||
"page/evaluations/app_eval.py",
|
||||
title="Evaluations (Scoring)",
|
||||
icon="📊",
|
||||
default=False,
|
||||
)
|
||||
native_evaluation_page = st.Page(
|
||||
"page/evaluations/native_eval.py",
|
||||
title="Evaluations (Generation + Scoring)",
|
||||
icon="📊",
|
||||
default=False,
|
||||
)
|
||||
|
||||
# Playground pages
|
||||
chat_page = st.Page(
|
||||
"page/playground/chat.py", title="Chat", icon="💬", default=True
|
||||
)
|
||||
rag_page = st.Page("page/playground/rag.py", title="RAG", icon="💬", default=False)
|
||||
|
||||
# Distribution pages
|
||||
resources_page = st.Page(
|
||||
"page/distribution/resources.py", title="Resources", icon="🔍", default=False
|
||||
)
|
||||
provider_page = st.Page(
|
||||
"page/distribution/providers.py",
|
||||
title="API Providers",
|
||||
icon="🔍",
|
||||
default=False,
|
||||
)
|
||||
|
||||
pg = st.navigation(
|
||||
{
|
||||
"Playground": [
|
||||
chat_page,
|
||||
rag_page,
|
||||
application_evaluation_page,
|
||||
native_evaluation_page,
|
||||
],
|
||||
"Inspect": [provider_page, resources_page],
|
||||
},
|
||||
expanded=False,
|
||||
)
|
||||
pg.run()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
5
llama_stack/distribution/ui/modules/__init__.py
Normal file
5
llama_stack/distribution/ui/modules/__init__.py
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
36
llama_stack/distribution/ui/modules/api.py
Normal file
36
llama_stack/distribution/ui/modules/api.py
Normal file
|
|
@ -0,0 +1,36 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import os
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from llama_stack_client import LlamaStackClient
|
||||
|
||||
|
||||
class LlamaStackApi:
|
||||
def __init__(self):
|
||||
self.client = LlamaStackClient(
|
||||
base_url=os.environ.get("LLAMA_STACK_ENDPOINT", "http://localhost:5000"),
|
||||
provider_data={
|
||||
"fireworks_api_key": os.environ.get("FIREWORKS_API_KEY", ""),
|
||||
"together_api_key": os.environ.get("TOGETHER_API_KEY", ""),
|
||||
"openai_api_key": os.environ.get("OPENAI_API_KEY", ""),
|
||||
},
|
||||
)
|
||||
|
||||
def run_scoring(
|
||||
self, row, scoring_function_ids: list[str], scoring_params: Optional[dict]
|
||||
):
|
||||
"""Run scoring on a single row"""
|
||||
if not scoring_params:
|
||||
scoring_params = {fn_id: None for fn_id in scoring_function_ids}
|
||||
return self.client.scoring.score(
|
||||
input_rows=[row], scoring_functions=scoring_params
|
||||
)
|
||||
|
||||
|
||||
llama_stack_api = LlamaStackApi()
|
||||
42
llama_stack/distribution/ui/modules/utils.py
Normal file
42
llama_stack/distribution/ui/modules/utils.py
Normal file
|
|
@ -0,0 +1,42 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import base64
|
||||
import os
|
||||
|
||||
import pandas as pd
|
||||
import streamlit as st
|
||||
|
||||
|
||||
def process_dataset(file):
|
||||
if file is None:
|
||||
return "No file uploaded", None
|
||||
|
||||
try:
|
||||
# Determine file type and read accordingly
|
||||
file_ext = os.path.splitext(file.name)[1].lower()
|
||||
if file_ext == ".csv":
|
||||
df = pd.read_csv(file)
|
||||
elif file_ext in [".xlsx", ".xls"]:
|
||||
df = pd.read_excel(file)
|
||||
else:
|
||||
return "Unsupported file format. Please upload a CSV or Excel file.", None
|
||||
|
||||
return df
|
||||
|
||||
except Exception as e:
|
||||
st.error(f"Error processing file: {str(e)}")
|
||||
return None
|
||||
|
||||
|
||||
def data_url_from_file(file) -> str:
|
||||
file_content = file.getvalue()
|
||||
base64_content = base64.b64encode(file_content).decode("utf-8")
|
||||
mime_type = file.type
|
||||
|
||||
data_url = f"data:{mime_type};base64,{base64_content}"
|
||||
|
||||
return data_url
|
||||
5
llama_stack/distribution/ui/page/__init__.py
Normal file
5
llama_stack/distribution/ui/page/__init__.py
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
19
llama_stack/distribution/ui/page/distribution/datasets.py
Normal file
19
llama_stack/distribution/ui/page/distribution/datasets.py
Normal file
|
|
@ -0,0 +1,19 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import streamlit as st
|
||||
from modules.api import llama_stack_api
|
||||
|
||||
|
||||
def datasets():
|
||||
st.header("Datasets")
|
||||
|
||||
datasets_info = {
|
||||
d.identifier: d.to_dict() for d in llama_stack_api.client.datasets.list()
|
||||
}
|
||||
|
||||
selected_dataset = st.selectbox("Select a dataset", list(datasets_info.keys()))
|
||||
st.json(datasets_info[selected_dataset], expanded=True)
|
||||
22
llama_stack/distribution/ui/page/distribution/eval_tasks.py
Normal file
22
llama_stack/distribution/ui/page/distribution/eval_tasks.py
Normal file
|
|
@ -0,0 +1,22 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import streamlit as st
|
||||
from modules.api import llama_stack_api
|
||||
|
||||
|
||||
def eval_tasks():
|
||||
# Eval Tasks Section
|
||||
st.header("Eval Tasks")
|
||||
|
||||
eval_tasks_info = {
|
||||
d.identifier: d.to_dict() for d in llama_stack_api.client.eval_tasks.list()
|
||||
}
|
||||
|
||||
selected_eval_task = st.selectbox(
|
||||
"Select an eval task", list(eval_tasks_info.keys()), key="eval_task_inspect"
|
||||
)
|
||||
st.json(eval_tasks_info[selected_eval_task], expanded=True)
|
||||
|
|
@ -0,0 +1,23 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import streamlit as st
|
||||
from modules.api import llama_stack_api
|
||||
|
||||
|
||||
def memory_banks():
|
||||
st.header("Memory Banks")
|
||||
memory_banks_info = {
|
||||
m.identifier: m.to_dict() for m in llama_stack_api.client.memory_banks.list()
|
||||
}
|
||||
|
||||
if len(memory_banks_info) > 0:
|
||||
selected_memory_bank = st.selectbox(
|
||||
"Select a memory bank", list(memory_banks_info.keys())
|
||||
)
|
||||
st.json(memory_banks_info[selected_memory_bank])
|
||||
else:
|
||||
st.info("No memory banks found")
|
||||
19
llama_stack/distribution/ui/page/distribution/models.py
Normal file
19
llama_stack/distribution/ui/page/distribution/models.py
Normal file
|
|
@ -0,0 +1,19 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import streamlit as st
|
||||
from modules.api import llama_stack_api
|
||||
|
||||
|
||||
def models():
|
||||
# Models Section
|
||||
st.header("Models")
|
||||
models_info = {
|
||||
m.identifier: m.to_dict() for m in llama_stack_api.client.models.list()
|
||||
}
|
||||
|
||||
selected_model = st.selectbox("Select a model", list(models_info.keys()))
|
||||
st.json(models_info[selected_model])
|
||||
20
llama_stack/distribution/ui/page/distribution/providers.py
Normal file
20
llama_stack/distribution/ui/page/distribution/providers.py
Normal file
|
|
@ -0,0 +1,20 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import streamlit as st
|
||||
from modules.api import llama_stack_api
|
||||
|
||||
|
||||
def providers():
|
||||
st.header("🔍 API Providers")
|
||||
apis_providers_info = llama_stack_api.client.providers.list()
|
||||
# selected_api = st.selectbox("Select an API", list(apis_providers_info.keys()))
|
||||
for api in apis_providers_info.keys():
|
||||
st.markdown(f"###### {api}")
|
||||
st.dataframe([p.to_dict() for p in apis_providers_info[api]], width=500)
|
||||
|
||||
|
||||
providers()
|
||||
52
llama_stack/distribution/ui/page/distribution/resources.py
Normal file
52
llama_stack/distribution/ui/page/distribution/resources.py
Normal file
|
|
@ -0,0 +1,52 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from page.distribution.datasets import datasets
|
||||
from page.distribution.eval_tasks import eval_tasks
|
||||
from page.distribution.memory_banks import memory_banks
|
||||
from page.distribution.models import models
|
||||
from page.distribution.scoring_functions import scoring_functions
|
||||
from page.distribution.shields import shields
|
||||
|
||||
from streamlit_option_menu import option_menu
|
||||
|
||||
|
||||
def resources_page():
|
||||
options = [
|
||||
"Models",
|
||||
"Memory Banks",
|
||||
"Shields",
|
||||
"Scoring Functions",
|
||||
"Datasets",
|
||||
"Eval Tasks",
|
||||
]
|
||||
icons = ["magic", "memory", "shield", "file-bar-graph", "database", "list-task"]
|
||||
selected_resource = option_menu(
|
||||
None,
|
||||
options,
|
||||
icons=icons,
|
||||
orientation="horizontal",
|
||||
styles={
|
||||
"nav-link": {
|
||||
"font-size": "12px",
|
||||
},
|
||||
},
|
||||
)
|
||||
if selected_resource == "Eval Tasks":
|
||||
eval_tasks()
|
||||
elif selected_resource == "Memory Banks":
|
||||
memory_banks()
|
||||
elif selected_resource == "Datasets":
|
||||
datasets()
|
||||
elif selected_resource == "Models":
|
||||
models()
|
||||
elif selected_resource == "Scoring Functions":
|
||||
scoring_functions()
|
||||
elif selected_resource == "Shields":
|
||||
shields()
|
||||
|
||||
|
||||
resources_page()
|
||||
|
|
@ -0,0 +1,22 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import streamlit as st
|
||||
from modules.api import llama_stack_api
|
||||
|
||||
|
||||
def scoring_functions():
|
||||
st.header("Scoring Functions")
|
||||
|
||||
scoring_functions_info = {
|
||||
s.identifier: s.to_dict()
|
||||
for s in llama_stack_api.client.scoring_functions.list()
|
||||
}
|
||||
|
||||
selected_scoring_function = st.selectbox(
|
||||
"Select a scoring function", list(scoring_functions_info.keys())
|
||||
)
|
||||
st.json(scoring_functions_info[selected_scoring_function], expanded=True)
|
||||
20
llama_stack/distribution/ui/page/distribution/shields.py
Normal file
20
llama_stack/distribution/ui/page/distribution/shields.py
Normal file
|
|
@ -0,0 +1,20 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import streamlit as st
|
||||
from modules.api import llama_stack_api
|
||||
|
||||
|
||||
def shields():
|
||||
# Shields Section
|
||||
st.header("Shields")
|
||||
|
||||
shields_info = {
|
||||
s.identifier: s.to_dict() for s in llama_stack_api.client.shields.list()
|
||||
}
|
||||
|
||||
selected_shield = st.selectbox("Select a shield", list(shields_info.keys()))
|
||||
st.json(shields_info[selected_shield])
|
||||
5
llama_stack/distribution/ui/page/evaluations/__init__.py
Normal file
5
llama_stack/distribution/ui/page/evaluations/__init__.py
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
148
llama_stack/distribution/ui/page/evaluations/app_eval.py
Normal file
148
llama_stack/distribution/ui/page/evaluations/app_eval.py
Normal file
|
|
@ -0,0 +1,148 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
|
||||
import pandas as pd
|
||||
import streamlit as st
|
||||
|
||||
from modules.api import llama_stack_api
|
||||
from modules.utils import process_dataset
|
||||
|
||||
|
||||
def application_evaluation_page():
|
||||
|
||||
st.set_page_config(page_title="Evaluations (Scoring)", page_icon="🦙")
|
||||
st.title("📊 Evaluations (Scoring)")
|
||||
|
||||
# File uploader
|
||||
uploaded_file = st.file_uploader("Upload Dataset", type=["csv", "xlsx", "xls"])
|
||||
|
||||
if uploaded_file is None:
|
||||
st.error("No file uploaded")
|
||||
return
|
||||
|
||||
# Process uploaded file
|
||||
df = process_dataset(uploaded_file)
|
||||
if df is None:
|
||||
st.error("Error processing file")
|
||||
return
|
||||
|
||||
# Display dataset information
|
||||
st.success("Dataset loaded successfully!")
|
||||
|
||||
# Display dataframe preview
|
||||
st.subheader("Dataset Preview")
|
||||
st.dataframe(df)
|
||||
|
||||
# Select Scoring Functions to Run Evaluation On
|
||||
st.subheader("Select Scoring Functions")
|
||||
scoring_functions = llama_stack_api.client.scoring_functions.list()
|
||||
scoring_functions = {sf.identifier: sf for sf in scoring_functions}
|
||||
scoring_functions_names = list(scoring_functions.keys())
|
||||
selected_scoring_functions = st.multiselect(
|
||||
"Choose one or more scoring functions",
|
||||
options=scoring_functions_names,
|
||||
help="Choose one or more scoring functions.",
|
||||
)
|
||||
|
||||
available_models = llama_stack_api.client.models.list()
|
||||
available_models = [m.identifier for m in available_models]
|
||||
|
||||
scoring_params = {}
|
||||
if selected_scoring_functions:
|
||||
st.write("Selected:")
|
||||
for scoring_fn_id in selected_scoring_functions:
|
||||
scoring_fn = scoring_functions[scoring_fn_id]
|
||||
st.write(f"- **{scoring_fn_id}**: {scoring_fn.description}")
|
||||
new_params = None
|
||||
if scoring_fn.params:
|
||||
new_params = {}
|
||||
for param_name, param_value in scoring_fn.params.to_dict().items():
|
||||
if param_name == "type":
|
||||
new_params[param_name] = param_value
|
||||
continue
|
||||
|
||||
if param_name == "judge_model":
|
||||
value = st.selectbox(
|
||||
f"Select **{param_name}** for {scoring_fn_id}",
|
||||
options=available_models,
|
||||
index=0,
|
||||
key=f"{scoring_fn_id}_{param_name}",
|
||||
)
|
||||
new_params[param_name] = value
|
||||
else:
|
||||
value = st.text_area(
|
||||
f"Enter value for **{param_name}** in {scoring_fn_id} in valid JSON format",
|
||||
value=json.dumps(param_value, indent=2),
|
||||
height=80,
|
||||
)
|
||||
try:
|
||||
new_params[param_name] = json.loads(value)
|
||||
except json.JSONDecodeError:
|
||||
st.error(
|
||||
f"Invalid JSON for **{param_name}** in {scoring_fn_id}"
|
||||
)
|
||||
|
||||
st.json(new_params)
|
||||
scoring_params[scoring_fn_id] = new_params
|
||||
|
||||
# Add run evaluation button & slider
|
||||
total_rows = len(df)
|
||||
num_rows = st.slider("Number of rows to evaluate", 1, total_rows, total_rows)
|
||||
|
||||
if st.button("Run Evaluation"):
|
||||
progress_text = "Running evaluation..."
|
||||
progress_bar = st.progress(0, text=progress_text)
|
||||
rows = df.to_dict(orient="records")
|
||||
if num_rows < total_rows:
|
||||
rows = rows[:num_rows]
|
||||
|
||||
# Create separate containers for progress text and results
|
||||
progress_text_container = st.empty()
|
||||
results_container = st.empty()
|
||||
output_res = {}
|
||||
for i, r in enumerate(rows):
|
||||
# Update progress
|
||||
progress = i / len(rows)
|
||||
progress_bar.progress(progress, text=progress_text)
|
||||
|
||||
# Run evaluation for current row
|
||||
score_res = llama_stack_api.run_scoring(
|
||||
r,
|
||||
scoring_function_ids=selected_scoring_functions,
|
||||
scoring_params=scoring_params,
|
||||
)
|
||||
|
||||
for k in r.keys():
|
||||
if k not in output_res:
|
||||
output_res[k] = []
|
||||
output_res[k].append(r[k])
|
||||
|
||||
for fn_id in selected_scoring_functions:
|
||||
if fn_id not in output_res:
|
||||
output_res[fn_id] = []
|
||||
output_res[fn_id].append(score_res.results[fn_id].score_rows[0])
|
||||
|
||||
# Display current row results using separate containers
|
||||
progress_text_container.write(
|
||||
f"Expand to see current processed result ({i+1}/{len(rows)})"
|
||||
)
|
||||
results_container.json(
|
||||
score_res.to_json(),
|
||||
expanded=2,
|
||||
)
|
||||
|
||||
progress_bar.progress(1.0, text="Evaluation complete!")
|
||||
|
||||
# Display results in dataframe
|
||||
if output_res:
|
||||
output_df = pd.DataFrame(output_res)
|
||||
st.subheader("Evaluation Results")
|
||||
st.dataframe(output_df)
|
||||
|
||||
|
||||
application_evaluation_page()
|
||||
257
llama_stack/distribution/ui/page/evaluations/native_eval.py
Normal file
257
llama_stack/distribution/ui/page/evaluations/native_eval.py
Normal file
|
|
@ -0,0 +1,257 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
|
||||
import pandas as pd
|
||||
|
||||
import streamlit as st
|
||||
|
||||
from modules.api import llama_stack_api
|
||||
|
||||
|
||||
def select_eval_task_1():
|
||||
# Select Eval Tasks
|
||||
st.subheader("1. Choose An Eval Task")
|
||||
eval_tasks = llama_stack_api.client.eval_tasks.list()
|
||||
eval_tasks = {et.identifier: et for et in eval_tasks}
|
||||
eval_tasks_names = list(eval_tasks.keys())
|
||||
selected_eval_task = st.selectbox(
|
||||
"Choose an eval task.",
|
||||
options=eval_tasks_names,
|
||||
help="Choose an eval task. Each eval task is parameterized by a dataset, and list of scoring functions.",
|
||||
)
|
||||
with st.expander("View Eval Task"):
|
||||
st.json(eval_tasks[selected_eval_task], expanded=True)
|
||||
|
||||
st.session_state["selected_eval_task"] = selected_eval_task
|
||||
st.session_state["eval_tasks"] = eval_tasks
|
||||
if st.button("Confirm", key="confirm_1"):
|
||||
st.session_state["selected_eval_task_1_next"] = True
|
||||
|
||||
|
||||
def define_eval_candidate_2():
|
||||
if not st.session_state.get("selected_eval_task_1_next", None):
|
||||
return
|
||||
|
||||
st.subheader("2. Define Eval Candidate")
|
||||
st.info(
|
||||
"""
|
||||
Define the configurations for the evaluation candidate model or agent used for generation.
|
||||
Select "model" if you want to run generation with inference API, or "agent" if you want to run generation with agent API through specifying AgentConfig.
|
||||
"""
|
||||
)
|
||||
with st.expander("Define Eval Candidate", expanded=True):
|
||||
# Define Eval Candidate
|
||||
candidate_type = st.radio("Candidate Type", ["model", "agent"])
|
||||
|
||||
available_models = llama_stack_api.client.models.list()
|
||||
available_models = [model.identifier for model in available_models]
|
||||
selected_model = st.selectbox(
|
||||
"Choose a model",
|
||||
available_models,
|
||||
index=0,
|
||||
)
|
||||
|
||||
# Sampling Parameters
|
||||
st.markdown("##### Sampling Parameters")
|
||||
strategy = st.selectbox(
|
||||
"Strategy",
|
||||
["greedy", "top_p", "top_k"],
|
||||
index=0,
|
||||
)
|
||||
temperature = st.slider(
|
||||
"Temperature",
|
||||
min_value=0.0,
|
||||
max_value=1.0,
|
||||
value=0.0,
|
||||
step=0.1,
|
||||
help="Controls the randomness of the response. Higher values make the output more creative and unexpected, lower values make it more conservative and predictable",
|
||||
)
|
||||
top_p = st.slider(
|
||||
"Top P",
|
||||
min_value=0.0,
|
||||
max_value=1.0,
|
||||
value=0.95,
|
||||
step=0.1,
|
||||
)
|
||||
max_tokens = st.slider(
|
||||
"Max Tokens",
|
||||
min_value=0,
|
||||
max_value=4096,
|
||||
value=512,
|
||||
step=1,
|
||||
help="The maximum number of tokens to generate",
|
||||
)
|
||||
repetition_penalty = st.slider(
|
||||
"Repetition Penalty",
|
||||
min_value=1.0,
|
||||
max_value=2.0,
|
||||
value=1.0,
|
||||
step=0.1,
|
||||
help="Controls the likelihood for generating the same word or phrase multiple times in the same sentence or paragraph. 1 implies no penalty, 2 will strongly discourage model to repeat words or phrases.",
|
||||
)
|
||||
if candidate_type == "model":
|
||||
eval_candidate = {
|
||||
"type": "model",
|
||||
"model": selected_model,
|
||||
"sampling_params": {
|
||||
"strategy": strategy,
|
||||
"temperature": temperature,
|
||||
"top_p": top_p,
|
||||
"max_tokens": max_tokens,
|
||||
"repetition_penalty": repetition_penalty,
|
||||
},
|
||||
}
|
||||
elif candidate_type == "agent":
|
||||
system_prompt = st.text_area(
|
||||
"System Prompt",
|
||||
value="You are a helpful AI assistant.",
|
||||
help="Initial instructions given to the AI to set its behavior and context",
|
||||
)
|
||||
tools_json = st.text_area(
|
||||
"Tools Configuration (JSON)",
|
||||
value=json.dumps(
|
||||
[
|
||||
{
|
||||
"type": "brave_search",
|
||||
"engine": "brave",
|
||||
"api_key": "ENTER_BRAVE_API_KEY_HERE",
|
||||
}
|
||||
]
|
||||
),
|
||||
help="Enter tool configurations in JSON format. Each tool should have a name, description, and parameters.",
|
||||
height=200,
|
||||
)
|
||||
try:
|
||||
tools = json.loads(tools_json)
|
||||
except json.JSONDecodeError:
|
||||
st.error("Invalid JSON format for tools configuration")
|
||||
tools = []
|
||||
eval_candidate = {
|
||||
"type": "agent",
|
||||
"config": {
|
||||
"model": selected_model,
|
||||
"instructions": system_prompt,
|
||||
"tools": tools,
|
||||
"tool_choice": "auto",
|
||||
"tool_prompt_format": "json",
|
||||
"input_shields": [],
|
||||
"output_shields": [],
|
||||
"enable_session_persistence": False,
|
||||
},
|
||||
}
|
||||
st.session_state["eval_candidate"] = eval_candidate
|
||||
|
||||
if st.button("Confirm", key="confirm_2"):
|
||||
st.session_state["selected_eval_candidate_2_next"] = True
|
||||
|
||||
|
||||
def run_evaluation_3():
|
||||
if not st.session_state.get("selected_eval_candidate_2_next", None):
|
||||
return
|
||||
|
||||
st.subheader("3. Run Evaluation")
|
||||
# Add info box to explain configurations being used
|
||||
st.info(
|
||||
"""
|
||||
Review the configurations that will be used for this evaluation run, make any necessary changes, and then click the "Run Evaluation" button.
|
||||
"""
|
||||
)
|
||||
selected_eval_task = st.session_state["selected_eval_task"]
|
||||
eval_tasks = st.session_state["eval_tasks"]
|
||||
eval_candidate = st.session_state["eval_candidate"]
|
||||
|
||||
dataset_id = eval_tasks[selected_eval_task].dataset_id
|
||||
rows = llama_stack_api.client.datasetio.get_rows_paginated(
|
||||
dataset_id=dataset_id,
|
||||
rows_in_page=-1,
|
||||
)
|
||||
total_rows = len(rows.rows)
|
||||
# Add number of examples control
|
||||
num_rows = st.number_input(
|
||||
"Number of Examples to Evaluate",
|
||||
min_value=1,
|
||||
max_value=total_rows,
|
||||
value=5,
|
||||
help="Number of examples from the dataset to evaluate. ",
|
||||
)
|
||||
|
||||
eval_task_config = {
|
||||
"type": "benchmark",
|
||||
"eval_candidate": eval_candidate,
|
||||
"scoring_params": {},
|
||||
}
|
||||
|
||||
with st.expander("View Evaluation Task", expanded=True):
|
||||
st.json(eval_tasks[selected_eval_task], expanded=True)
|
||||
with st.expander("View Evaluation Task Configuration", expanded=True):
|
||||
st.json(eval_task_config, expanded=True)
|
||||
|
||||
# Add run button and handle evaluation
|
||||
if st.button("Run Evaluation"):
|
||||
|
||||
progress_text = "Running evaluation..."
|
||||
progress_bar = st.progress(0, text=progress_text)
|
||||
rows = rows.rows
|
||||
if num_rows < total_rows:
|
||||
rows = rows[:num_rows]
|
||||
|
||||
# Create separate containers for progress text and results
|
||||
progress_text_container = st.empty()
|
||||
results_container = st.empty()
|
||||
output_res = {}
|
||||
for i, r in enumerate(rows):
|
||||
# Update progress
|
||||
progress = i / len(rows)
|
||||
progress_bar.progress(progress, text=progress_text)
|
||||
# Run evaluation for current row
|
||||
eval_res = llama_stack_api.client.eval.evaluate_rows(
|
||||
task_id=selected_eval_task,
|
||||
input_rows=[r],
|
||||
scoring_functions=eval_tasks[selected_eval_task].scoring_functions,
|
||||
task_config=eval_task_config,
|
||||
)
|
||||
|
||||
for k in r.keys():
|
||||
if k not in output_res:
|
||||
output_res[k] = []
|
||||
output_res[k].append(r[k])
|
||||
|
||||
for k in eval_res.generations[0].keys():
|
||||
if k not in output_res:
|
||||
output_res[k] = []
|
||||
output_res[k].append(eval_res.generations[0][k])
|
||||
|
||||
for scoring_fn in eval_tasks[selected_eval_task].scoring_functions:
|
||||
if scoring_fn not in output_res:
|
||||
output_res[scoring_fn] = []
|
||||
output_res[scoring_fn].append(eval_res.scores[scoring_fn].score_rows[0])
|
||||
|
||||
progress_text_container.write(
|
||||
f"Expand to see current processed result ({i+1}/{len(rows)})"
|
||||
)
|
||||
results_container.json(eval_res, expanded=2)
|
||||
|
||||
progress_bar.progress(1.0, text="Evaluation complete!")
|
||||
# Display results in dataframe
|
||||
if output_res:
|
||||
output_df = pd.DataFrame(output_res)
|
||||
st.subheader("Evaluation Results")
|
||||
st.dataframe(output_df)
|
||||
|
||||
|
||||
def native_evaluation_page():
|
||||
|
||||
st.set_page_config(page_title="Evaluations (Generation + Scoring)", page_icon="🦙")
|
||||
st.title("📊 Evaluations (Generation + Scoring)")
|
||||
|
||||
select_eval_task_1()
|
||||
define_eval_candidate_2()
|
||||
run_evaluation_3()
|
||||
|
||||
|
||||
native_evaluation_page()
|
||||
5
llama_stack/distribution/ui/page/playground/__init__.py
Normal file
5
llama_stack/distribution/ui/page/playground/__init__.py
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
123
llama_stack/distribution/ui/page/playground/chat.py
Normal file
123
llama_stack/distribution/ui/page/playground/chat.py
Normal file
|
|
@ -0,0 +1,123 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import streamlit as st
|
||||
from modules.api import llama_stack_api
|
||||
|
||||
# Sidebar configurations
|
||||
with st.sidebar:
|
||||
st.header("Configuration")
|
||||
available_models = llama_stack_api.client.models.list()
|
||||
available_models = [model.identifier for model in available_models]
|
||||
selected_model = st.selectbox(
|
||||
"Choose a model",
|
||||
available_models,
|
||||
index=0,
|
||||
)
|
||||
|
||||
temperature = st.slider(
|
||||
"Temperature",
|
||||
min_value=0.0,
|
||||
max_value=1.0,
|
||||
value=0.0,
|
||||
step=0.1,
|
||||
help="Controls the randomness of the response. Higher values make the output more creative and unexpected, lower values make it more conservative and predictable",
|
||||
)
|
||||
|
||||
top_p = st.slider(
|
||||
"Top P",
|
||||
min_value=0.0,
|
||||
max_value=1.0,
|
||||
value=0.95,
|
||||
step=0.1,
|
||||
)
|
||||
|
||||
max_tokens = st.slider(
|
||||
"Max Tokens",
|
||||
min_value=0,
|
||||
max_value=4096,
|
||||
value=512,
|
||||
step=1,
|
||||
help="The maximum number of tokens to generate",
|
||||
)
|
||||
|
||||
repetition_penalty = st.slider(
|
||||
"Repetition Penalty",
|
||||
min_value=1.0,
|
||||
max_value=2.0,
|
||||
value=1.0,
|
||||
step=0.1,
|
||||
help="Controls the likelihood for generating the same word or phrase multiple times in the same sentence or paragraph. 1 implies no penalty, 2 will strongly discourage model to repeat words or phrases.",
|
||||
)
|
||||
|
||||
stream = st.checkbox("Stream", value=True)
|
||||
system_prompt = st.text_area(
|
||||
"System Prompt",
|
||||
value="You are a helpful AI assistant.",
|
||||
help="Initial instructions given to the AI to set its behavior and context",
|
||||
)
|
||||
|
||||
# Add clear chat button to sidebar
|
||||
if st.button("Clear Chat", use_container_width=True):
|
||||
st.session_state.messages = []
|
||||
st.rerun()
|
||||
|
||||
|
||||
# Main chat interface
|
||||
st.title("🦙 Chat")
|
||||
|
||||
|
||||
# Initialize chat history
|
||||
if "messages" not in st.session_state:
|
||||
st.session_state.messages = []
|
||||
|
||||
# Display chat messages
|
||||
for message in st.session_state.messages:
|
||||
with st.chat_message(message["role"]):
|
||||
st.markdown(message["content"])
|
||||
|
||||
# Chat input
|
||||
if prompt := st.chat_input("Example: What is Llama Stack?"):
|
||||
# Add user message to chat history
|
||||
st.session_state.messages.append({"role": "user", "content": prompt})
|
||||
|
||||
# Display user message
|
||||
with st.chat_message("user"):
|
||||
st.markdown(prompt)
|
||||
|
||||
# Display assistant response
|
||||
with st.chat_message("assistant"):
|
||||
message_placeholder = st.empty()
|
||||
full_response = ""
|
||||
|
||||
response = llama_stack_api.client.inference.chat_completion(
|
||||
messages=[
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": prompt},
|
||||
],
|
||||
model_id=selected_model,
|
||||
stream=stream,
|
||||
sampling_params={
|
||||
"temperature": temperature,
|
||||
"top_p": top_p,
|
||||
"max_tokens": max_tokens,
|
||||
"repetition_penalty": repetition_penalty,
|
||||
},
|
||||
)
|
||||
|
||||
if stream:
|
||||
for chunk in response:
|
||||
if chunk.event.event_type == "progress":
|
||||
full_response += chunk.event.delta
|
||||
message_placeholder.markdown(full_response + "▌")
|
||||
message_placeholder.markdown(full_response)
|
||||
else:
|
||||
full_response = response
|
||||
message_placeholder.markdown(full_response.completion_message.content)
|
||||
|
||||
st.session_state.messages.append(
|
||||
{"role": "assistant", "content": full_response}
|
||||
)
|
||||
188
llama_stack/distribution/ui/page/playground/rag.py
Normal file
188
llama_stack/distribution/ui/page/playground/rag.py
Normal file
|
|
@ -0,0 +1,188 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import streamlit as st
|
||||
from llama_stack_client.lib.agents.agent import Agent
|
||||
from llama_stack_client.lib.agents.event_logger import EventLogger
|
||||
from llama_stack_client.types.agent_create_params import AgentConfig
|
||||
from llama_stack_client.types.memory_insert_params import Document
|
||||
|
||||
from modules.api import llama_stack_api
|
||||
from modules.utils import data_url_from_file
|
||||
|
||||
|
||||
def rag_chat_page():
|
||||
st.title("🦙 RAG")
|
||||
|
||||
with st.sidebar:
|
||||
# File/Directory Upload Section
|
||||
st.subheader("Upload Documents")
|
||||
uploaded_files = st.file_uploader(
|
||||
"Upload file(s) or directory",
|
||||
accept_multiple_files=True,
|
||||
type=["txt", "pdf", "doc", "docx"], # Add more file types as needed
|
||||
)
|
||||
# Process uploaded files
|
||||
if uploaded_files:
|
||||
st.success(f"Successfully uploaded {len(uploaded_files)} files")
|
||||
# Add memory bank name input field
|
||||
memory_bank_name = st.text_input(
|
||||
"Memory Bank Name",
|
||||
value="rag_bank",
|
||||
help="Enter a unique identifier for this memory bank",
|
||||
)
|
||||
if st.button("Create Memory Bank"):
|
||||
documents = [
|
||||
Document(
|
||||
document_id=uploaded_file.name,
|
||||
content=data_url_from_file(uploaded_file),
|
||||
)
|
||||
for i, uploaded_file in enumerate(uploaded_files)
|
||||
]
|
||||
|
||||
providers = llama_stack_api.client.providers.list()
|
||||
llama_stack_api.client.memory_banks.register(
|
||||
memory_bank_id=memory_bank_name, # Use the user-provided name
|
||||
params={
|
||||
"embedding_model": "all-MiniLM-L6-v2",
|
||||
"chunk_size_in_tokens": 512,
|
||||
"overlap_size_in_tokens": 64,
|
||||
},
|
||||
provider_id=providers["memory"][0].provider_id,
|
||||
)
|
||||
|
||||
# insert documents using the custom bank name
|
||||
llama_stack_api.client.memory.insert(
|
||||
bank_id=memory_bank_name, # Use the user-provided name
|
||||
documents=documents,
|
||||
)
|
||||
st.success("Memory bank created successfully!")
|
||||
|
||||
st.subheader("Configure Agent")
|
||||
# select memory banks
|
||||
memory_banks = llama_stack_api.client.memory_banks.list()
|
||||
memory_banks = [bank.identifier for bank in memory_banks]
|
||||
selected_memory_banks = st.multiselect(
|
||||
"Select Memory Banks",
|
||||
memory_banks,
|
||||
)
|
||||
memory_bank_configs = [
|
||||
{"bank_id": bank_id, "type": "vector"} for bank_id in selected_memory_banks
|
||||
]
|
||||
|
||||
available_models = llama_stack_api.client.models.list()
|
||||
available_models = [model.identifier for model in available_models]
|
||||
selected_model = st.selectbox(
|
||||
"Choose a model",
|
||||
available_models,
|
||||
index=0,
|
||||
)
|
||||
system_prompt = st.text_area(
|
||||
"System Prompt",
|
||||
value="You are a helpful assistant. ",
|
||||
help="Initial instructions given to the AI to set its behavior and context",
|
||||
)
|
||||
temperature = st.slider(
|
||||
"Temperature",
|
||||
min_value=0.0,
|
||||
max_value=1.0,
|
||||
value=0.0,
|
||||
step=0.1,
|
||||
help="Controls the randomness of the response. Higher values make the output more creative and unexpected, lower values make it more conservative and predictable",
|
||||
)
|
||||
|
||||
top_p = st.slider(
|
||||
"Top P",
|
||||
min_value=0.0,
|
||||
max_value=1.0,
|
||||
value=0.95,
|
||||
step=0.1,
|
||||
)
|
||||
|
||||
# Add clear chat button to sidebar
|
||||
if st.button("Clear Chat", use_container_width=True):
|
||||
st.session_state.messages = []
|
||||
st.rerun()
|
||||
|
||||
# Chat Interface
|
||||
if "messages" not in st.session_state:
|
||||
st.session_state.messages = []
|
||||
|
||||
# Display chat history
|
||||
for message in st.session_state.messages:
|
||||
with st.chat_message(message["role"]):
|
||||
st.markdown(message["content"])
|
||||
|
||||
selected_model = llama_stack_api.client.models.list()[0].identifier
|
||||
|
||||
agent_config = AgentConfig(
|
||||
model=selected_model,
|
||||
instructions=system_prompt,
|
||||
sampling_params={
|
||||
"strategy": "greedy",
|
||||
"temperature": temperature,
|
||||
"top_p": top_p,
|
||||
},
|
||||
tools=[
|
||||
{
|
||||
"type": "memory",
|
||||
"memory_bank_configs": memory_bank_configs,
|
||||
"query_generator_config": {"type": "default", "sep": " "},
|
||||
"max_tokens_in_context": 4096,
|
||||
"max_chunks": 10,
|
||||
}
|
||||
],
|
||||
tool_choice="auto",
|
||||
tool_prompt_format="json",
|
||||
input_shields=[],
|
||||
output_shields=[],
|
||||
enable_session_persistence=False,
|
||||
)
|
||||
|
||||
agent = Agent(llama_stack_api.client, agent_config)
|
||||
session_id = agent.create_session("rag-session")
|
||||
|
||||
# Chat input
|
||||
if prompt := st.chat_input("Ask a question about your documents"):
|
||||
# Add user message to chat history
|
||||
st.session_state.messages.append({"role": "user", "content": prompt})
|
||||
|
||||
# Display user message
|
||||
with st.chat_message("user"):
|
||||
st.markdown(prompt)
|
||||
|
||||
response = agent.create_turn(
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": prompt,
|
||||
}
|
||||
],
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Display assistant response
|
||||
with st.chat_message("assistant"):
|
||||
retrieval_message_placeholder = st.empty()
|
||||
message_placeholder = st.empty()
|
||||
full_response = ""
|
||||
retrieval_response = ""
|
||||
for log in EventLogger().log(response):
|
||||
log.print()
|
||||
if log.role == "memory_retrieval":
|
||||
retrieval_response += log.content.replace("====", "").strip()
|
||||
retrieval_message_placeholder.info(retrieval_response)
|
||||
else:
|
||||
full_response += log.content
|
||||
message_placeholder.markdown(full_response + "▌")
|
||||
message_placeholder.markdown(full_response)
|
||||
|
||||
st.session_state.messages.append(
|
||||
{"role": "assistant", "content": full_response}
|
||||
)
|
||||
|
||||
|
||||
rag_chat_page()
|
||||
4
llama_stack/distribution/ui/requirements.txt
Normal file
4
llama_stack/distribution/ui/requirements.txt
Normal file
|
|
@ -0,0 +1,4 @@
|
|||
streamlit
|
||||
pandas
|
||||
llama-stack-client>=0.0.55
|
||||
streamlit-option-menu
|
||||
|
|
@ -4,11 +4,10 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from .config_dirs import DEFAULT_CHECKPOINT_DIR
|
||||
|
||||
|
||||
def model_local_dir(descriptor: str) -> str:
|
||||
path = os.path.join(DEFAULT_CHECKPOINT_DIR, descriptor)
|
||||
return path.replace(":", "-")
|
||||
return str(Path(DEFAULT_CHECKPOINT_DIR) / (descriptor.replace(":", "-")))
|
||||
|
|
|
|||
|
|
@ -54,8 +54,6 @@ class ShieldsProtocolPrivate(Protocol):
|
|||
|
||||
|
||||
class MemoryBanksProtocolPrivate(Protocol):
|
||||
async def list_memory_banks(self) -> List[MemoryBank]: ...
|
||||
|
||||
async def register_memory_bank(self, memory_bank: MemoryBank) -> None: ...
|
||||
|
||||
async def unregister_memory_bank(self, memory_bank_id: str) -> None: ...
|
||||
|
|
@ -64,6 +62,8 @@ class MemoryBanksProtocolPrivate(Protocol):
|
|||
class DatasetsProtocolPrivate(Protocol):
|
||||
async def register_dataset(self, dataset: Dataset) -> None: ...
|
||||
|
||||
async def unregister_dataset(self, dataset_id: str) -> None: ...
|
||||
|
||||
|
||||
class ScoringFunctionsProtocolPrivate(Protocol):
|
||||
async def list_scoring_functions(self) -> List[ScoringFn]: ...
|
||||
|
|
@ -201,10 +201,13 @@ API responses, specify the adapter here.
|
|||
return self.adapter.provider_data_validator
|
||||
|
||||
|
||||
def remote_provider_spec(api: Api, adapter: AdapterSpec) -> RemoteProviderSpec:
|
||||
def remote_provider_spec(
|
||||
api: Api, adapter: AdapterSpec, api_dependencies: Optional[List[Api]] = None
|
||||
) -> RemoteProviderSpec:
|
||||
return RemoteProviderSpec(
|
||||
api=api,
|
||||
provider_type=f"remote::{adapter.adapter_type}",
|
||||
config_class=adapter.config_class,
|
||||
adapter=adapter,
|
||||
api_dependencies=api_dependencies or [],
|
||||
)
|
||||
|
|
|
|||
|
|
@ -10,9 +10,7 @@ import logging
|
|||
import os
|
||||
import re
|
||||
import secrets
|
||||
import shutil
|
||||
import string
|
||||
import tempfile
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import AsyncGenerator, List, Tuple
|
||||
|
|
@ -57,6 +55,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
self,
|
||||
agent_id: str,
|
||||
agent_config: AgentConfig,
|
||||
tempdir: str,
|
||||
inference_api: Inference,
|
||||
memory_api: Memory,
|
||||
memory_banks_api: MemoryBanks,
|
||||
|
|
@ -65,14 +64,13 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
):
|
||||
self.agent_id = agent_id
|
||||
self.agent_config = agent_config
|
||||
self.tempdir = tempdir
|
||||
self.inference_api = inference_api
|
||||
self.memory_api = memory_api
|
||||
self.memory_banks_api = memory_banks_api
|
||||
self.safety_api = safety_api
|
||||
self.storage = AgentPersistence(agent_id, persistence_store)
|
||||
|
||||
self.tempdir = tempfile.mkdtemp()
|
||||
|
||||
builtin_tools = []
|
||||
for tool_defn in agent_config.tools:
|
||||
if isinstance(tool_defn, WolframAlphaToolDefinition):
|
||||
|
|
@ -103,9 +101,6 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
output_shields=agent_config.output_shields,
|
||||
)
|
||||
|
||||
def __del__(self):
|
||||
shutil.rmtree(self.tempdir)
|
||||
|
||||
def turn_to_messages(self, turn: Turn) -> List[Message]:
|
||||
messages = []
|
||||
|
||||
|
|
@ -113,7 +108,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
# May be this should be a parameter of the agentic instance
|
||||
# that can define its behavior in a custom way
|
||||
for m in turn.input_messages:
|
||||
msg = m.copy()
|
||||
msg = m.model_copy()
|
||||
if isinstance(msg, UserMessage):
|
||||
msg.context = None
|
||||
messages.append(msg)
|
||||
|
|
@ -144,87 +139,91 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
async def create_session(self, name: str) -> str:
|
||||
return await self.storage.create_session(name)
|
||||
|
||||
@tracing.span("create_and_execute_turn")
|
||||
async def create_and_execute_turn(
|
||||
self, request: AgentTurnCreateRequest
|
||||
) -> AsyncGenerator:
|
||||
assert request.stream is True, "Non-streaming not supported"
|
||||
with tracing.span("create_and_execute_turn") as span:
|
||||
span.set_attribute("session_id", request.session_id)
|
||||
span.set_attribute("agent_id", self.agent_id)
|
||||
span.set_attribute("request", request.model_dump_json())
|
||||
assert request.stream is True, "Non-streaming not supported"
|
||||
|
||||
session_info = await self.storage.get_session_info(request.session_id)
|
||||
if session_info is None:
|
||||
raise ValueError(f"Session {request.session_id} not found")
|
||||
session_info = await self.storage.get_session_info(request.session_id)
|
||||
if session_info is None:
|
||||
raise ValueError(f"Session {request.session_id} not found")
|
||||
|
||||
turns = await self.storage.get_session_turns(request.session_id)
|
||||
turns = await self.storage.get_session_turns(request.session_id)
|
||||
|
||||
messages = []
|
||||
if self.agent_config.instructions != "":
|
||||
messages.append(SystemMessage(content=self.agent_config.instructions))
|
||||
messages = []
|
||||
if self.agent_config.instructions != "":
|
||||
messages.append(SystemMessage(content=self.agent_config.instructions))
|
||||
|
||||
for i, turn in enumerate(turns):
|
||||
messages.extend(self.turn_to_messages(turn))
|
||||
for i, turn in enumerate(turns):
|
||||
messages.extend(self.turn_to_messages(turn))
|
||||
|
||||
messages.extend(request.messages)
|
||||
messages.extend(request.messages)
|
||||
|
||||
turn_id = str(uuid.uuid4())
|
||||
start_time = datetime.now()
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseTurnStartPayload(
|
||||
turn_id=turn_id,
|
||||
turn_id = str(uuid.uuid4())
|
||||
span.set_attribute("turn_id", turn_id)
|
||||
start_time = datetime.now()
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseTurnStartPayload(
|
||||
turn_id=turn_id,
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
steps = []
|
||||
output_message = None
|
||||
async for chunk in self.run(
|
||||
session_id=request.session_id,
|
||||
turn_id=turn_id,
|
||||
input_messages=messages,
|
||||
attachments=request.attachments or [],
|
||||
sampling_params=self.agent_config.sampling_params,
|
||||
stream=request.stream,
|
||||
):
|
||||
if isinstance(chunk, CompletionMessage):
|
||||
log.info(
|
||||
f"{chunk.role.capitalize()}: {chunk.content}",
|
||||
)
|
||||
output_message = chunk
|
||||
continue
|
||||
|
||||
assert isinstance(
|
||||
chunk, AgentTurnResponseStreamChunk
|
||||
), f"Unexpected type {type(chunk)}"
|
||||
event = chunk.event
|
||||
if (
|
||||
event.payload.event_type
|
||||
== AgentTurnResponseEventType.step_complete.value
|
||||
steps = []
|
||||
output_message = None
|
||||
async for chunk in self.run(
|
||||
session_id=request.session_id,
|
||||
turn_id=turn_id,
|
||||
input_messages=messages,
|
||||
attachments=request.attachments or [],
|
||||
sampling_params=self.agent_config.sampling_params,
|
||||
stream=request.stream,
|
||||
):
|
||||
steps.append(event.payload.step_details)
|
||||
if isinstance(chunk, CompletionMessage):
|
||||
log.info(
|
||||
f"{chunk.role.capitalize()}: {chunk.content}",
|
||||
)
|
||||
output_message = chunk
|
||||
continue
|
||||
|
||||
yield chunk
|
||||
assert isinstance(
|
||||
chunk, AgentTurnResponseStreamChunk
|
||||
), f"Unexpected type {type(chunk)}"
|
||||
event = chunk.event
|
||||
if (
|
||||
event.payload.event_type
|
||||
== AgentTurnResponseEventType.step_complete.value
|
||||
):
|
||||
steps.append(event.payload.step_details)
|
||||
|
||||
assert output_message is not None
|
||||
yield chunk
|
||||
|
||||
turn = Turn(
|
||||
turn_id=turn_id,
|
||||
session_id=request.session_id,
|
||||
input_messages=request.messages,
|
||||
output_message=output_message,
|
||||
started_at=start_time,
|
||||
completed_at=datetime.now(),
|
||||
steps=steps,
|
||||
)
|
||||
await self.storage.add_turn_to_session(request.session_id, turn)
|
||||
assert output_message is not None
|
||||
|
||||
chunk = AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseTurnCompletePayload(
|
||||
turn=turn,
|
||||
turn = Turn(
|
||||
turn_id=turn_id,
|
||||
session_id=request.session_id,
|
||||
input_messages=request.messages,
|
||||
output_message=output_message,
|
||||
started_at=start_time,
|
||||
completed_at=datetime.now(),
|
||||
steps=steps,
|
||||
)
|
||||
await self.storage.add_turn_to_session(request.session_id, turn)
|
||||
|
||||
chunk = AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseTurnCompletePayload(
|
||||
turn=turn,
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
yield chunk
|
||||
yield chunk
|
||||
|
||||
async def run(
|
||||
self,
|
||||
|
|
@ -273,7 +272,6 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
|
||||
yield final_response
|
||||
|
||||
@tracing.span("run_shields")
|
||||
async def run_multiple_shields_wrapper(
|
||||
self,
|
||||
turn_id: str,
|
||||
|
|
@ -281,23 +279,46 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
shields: List[str],
|
||||
touchpoint: str,
|
||||
) -> AsyncGenerator:
|
||||
if len(shields) == 0:
|
||||
return
|
||||
with tracing.span("run_shields") as span:
|
||||
span.set_attribute("input", [m.model_dump_json() for m in messages])
|
||||
if len(shields) == 0:
|
||||
span.set_attribute("output", "no shields")
|
||||
return
|
||||
|
||||
step_id = str(uuid.uuid4())
|
||||
try:
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseStepStartPayload(
|
||||
step_type=StepType.shield_call.value,
|
||||
step_id=step_id,
|
||||
metadata=dict(touchpoint=touchpoint),
|
||||
step_id = str(uuid.uuid4())
|
||||
try:
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseStepStartPayload(
|
||||
step_type=StepType.shield_call.value,
|
||||
step_id=step_id,
|
||||
metadata=dict(touchpoint=touchpoint),
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
await self.run_multiple_shields(messages, shields)
|
||||
await self.run_multiple_shields(messages, shields)
|
||||
|
||||
except SafetyException as e:
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseStepCompletePayload(
|
||||
step_type=StepType.shield_call.value,
|
||||
step_details=ShieldCallStep(
|
||||
step_id=step_id,
|
||||
turn_id=turn_id,
|
||||
violation=e.violation,
|
||||
),
|
||||
)
|
||||
)
|
||||
)
|
||||
span.set_attribute("output", e.violation.model_dump_json())
|
||||
|
||||
yield CompletionMessage(
|
||||
content=str(e),
|
||||
stop_reason=StopReason.end_of_turn,
|
||||
)
|
||||
yield False
|
||||
|
||||
except SafetyException as e:
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseStepCompletePayload(
|
||||
|
|
@ -305,30 +326,12 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
step_details=ShieldCallStep(
|
||||
step_id=step_id,
|
||||
turn_id=turn_id,
|
||||
violation=e.violation,
|
||||
violation=None,
|
||||
),
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
yield CompletionMessage(
|
||||
content=str(e),
|
||||
stop_reason=StopReason.end_of_turn,
|
||||
)
|
||||
yield False
|
||||
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseStepCompletePayload(
|
||||
step_type=StepType.shield_call.value,
|
||||
step_details=ShieldCallStep(
|
||||
step_id=step_id,
|
||||
turn_id=turn_id,
|
||||
violation=None,
|
||||
),
|
||||
)
|
||||
)
|
||||
)
|
||||
span.set_attribute("output", "no violations")
|
||||
|
||||
async def _run(
|
||||
self,
|
||||
|
|
@ -356,10 +359,15 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
|
||||
# TODO: find older context from the session and either replace it
|
||||
# or append with a sliding window. this is really a very simplistic implementation
|
||||
with tracing.span("retrieve_rag_context"):
|
||||
with tracing.span("retrieve_rag_context") as span:
|
||||
rag_context, bank_ids = await self._retrieve_context(
|
||||
session_id, input_messages, attachments
|
||||
)
|
||||
span.set_attribute(
|
||||
"input", [m.model_dump_json() for m in input_messages]
|
||||
)
|
||||
span.set_attribute("output", rag_context)
|
||||
span.set_attribute("bank_ids", bank_ids)
|
||||
|
||||
step_id = str(uuid.uuid4())
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
|
|
@ -396,11 +404,6 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
n_iter = 0
|
||||
while True:
|
||||
msg = input_messages[-1]
|
||||
if len(str(msg)) > 1000:
|
||||
msg_str = f"{str(msg)[:500]}...<more>...{str(msg)[-500:]}"
|
||||
else:
|
||||
msg_str = str(msg)
|
||||
log.info(f"{msg_str}")
|
||||
|
||||
step_id = str(uuid.uuid4())
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
|
|
@ -416,7 +419,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
content = ""
|
||||
stop_reason = None
|
||||
|
||||
with tracing.span("inference"):
|
||||
with tracing.span("inference") as span:
|
||||
async for chunk in await self.inference_api.chat_completion(
|
||||
self.agent_config.model,
|
||||
input_messages,
|
||||
|
|
@ -436,14 +439,13 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
if isinstance(delta, ToolCallDelta):
|
||||
if delta.parse_status == ToolCallParseStatus.success:
|
||||
tool_calls.append(delta.content)
|
||||
|
||||
if stream:
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
payload=AgentTurnResponseStepProgressPayload(
|
||||
step_type=StepType.inference.value,
|
||||
step_id=step_id,
|
||||
model_response_text_delta="",
|
||||
text_delta="",
|
||||
tool_call_delta=delta,
|
||||
)
|
||||
)
|
||||
|
|
@ -457,7 +459,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
payload=AgentTurnResponseStepProgressPayload(
|
||||
step_type=StepType.inference.value,
|
||||
step_id=step_id,
|
||||
model_response_text_delta=event.delta,
|
||||
text_delta=event.delta,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
|
@ -466,6 +468,13 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
|
||||
if event.stop_reason is not None:
|
||||
stop_reason = event.stop_reason
|
||||
span.set_attribute("stop_reason", stop_reason)
|
||||
span.set_attribute(
|
||||
"input", [m.model_dump_json() for m in input_messages]
|
||||
)
|
||||
span.set_attribute(
|
||||
"output", f"content: {content} tool_calls: {tool_calls}"
|
||||
)
|
||||
|
||||
stop_reason = stop_reason or StopReason.out_of_tokens
|
||||
|
||||
|
|
@ -549,7 +558,13 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
)
|
||||
)
|
||||
|
||||
with tracing.span("tool_execution"):
|
||||
with tracing.span(
|
||||
"tool_execution",
|
||||
{
|
||||
"tool_name": tool_call.tool_name,
|
||||
"input": message.model_dump_json(),
|
||||
},
|
||||
) as span:
|
||||
result_messages = await execute_tool_call_maybe(
|
||||
self.tools_dict,
|
||||
[message],
|
||||
|
|
@ -558,6 +573,7 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
len(result_messages) == 1
|
||||
), "Currently not supporting multiple messages"
|
||||
result_message = result_messages[0]
|
||||
span.set_attribute("output", result_message.model_dump_json())
|
||||
|
||||
yield AgentTurnResponseStreamChunk(
|
||||
event=AgentTurnResponseEvent(
|
||||
|
|
|
|||
|
|
@ -6,9 +6,13 @@
|
|||
|
||||
import json
|
||||
import logging
|
||||
import shutil
|
||||
import tempfile
|
||||
import uuid
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from termcolor import colored
|
||||
|
||||
from llama_stack.apis.inference import Inference
|
||||
from llama_stack.apis.memory import Memory
|
||||
from llama_stack.apis.memory_banks import MemoryBanks
|
||||
|
|
@ -40,10 +44,20 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
self.memory_banks_api = memory_banks_api
|
||||
|
||||
self.in_memory_store = InmemoryKVStoreImpl()
|
||||
self.tempdir = tempfile.mkdtemp()
|
||||
|
||||
async def initialize(self) -> None:
|
||||
self.persistence_store = await kvstore_impl(self.config.persistence_store)
|
||||
|
||||
# check if "bwrap" is available
|
||||
if not shutil.which("bwrap"):
|
||||
print(
|
||||
colored(
|
||||
"Warning: `bwrap` is not available. Code interpreter tool will not work correctly.",
|
||||
"yellow",
|
||||
)
|
||||
)
|
||||
|
||||
async def create_agent(
|
||||
self,
|
||||
agent_config: AgentConfig,
|
||||
|
|
@ -52,7 +66,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
|
||||
await self.persistence_store.set(
|
||||
key=f"agent:{agent_id}",
|
||||
value=agent_config.json(),
|
||||
value=agent_config.model_dump_json(),
|
||||
)
|
||||
return AgentCreateResponse(
|
||||
agent_id=agent_id,
|
||||
|
|
@ -82,6 +96,7 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
return ChatAgent(
|
||||
agent_id=agent_id,
|
||||
agent_config=agent_config,
|
||||
tempdir=self.tempdir,
|
||||
inference_api=self.inference_api,
|
||||
safety_api=self.safety_api,
|
||||
memory_api=self.memory_api,
|
||||
|
|
|
|||
|
|
@ -39,7 +39,7 @@ class AgentPersistence:
|
|||
)
|
||||
await self.kvstore.set(
|
||||
key=f"session:{self.agent_id}:{session_id}",
|
||||
value=session_info.json(),
|
||||
value=session_info.model_dump_json(),
|
||||
)
|
||||
return session_id
|
||||
|
||||
|
|
@ -60,13 +60,13 @@ class AgentPersistence:
|
|||
session_info.memory_bank_id = bank_id
|
||||
await self.kvstore.set(
|
||||
key=f"session:{self.agent_id}:{session_id}",
|
||||
value=session_info.json(),
|
||||
value=session_info.model_dump_json(),
|
||||
)
|
||||
|
||||
async def add_turn_to_session(self, session_id: str, turn: Turn):
|
||||
await self.kvstore.set(
|
||||
key=f"session:{self.agent_id}:{session_id}:{turn.turn_id}",
|
||||
value=turn.json(),
|
||||
value=turn.model_dump_json(),
|
||||
)
|
||||
|
||||
async def get_session_turns(self, session_id: str) -> List[Turn]:
|
||||
|
|
|
|||
5
llama_stack/providers/inline/datasetio/__init__.py
Normal file
5
llama_stack/providers/inline/datasetio/__init__.py
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
|
@ -3,14 +3,17 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import Optional
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import pandas
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
|
||||
from llama_stack.apis.datasetio import * # noqa: F403
|
||||
import base64
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from llama_stack.providers.datatypes import DatasetsProtocolPrivate
|
||||
from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url
|
||||
|
|
@ -97,6 +100,9 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
|
|||
dataset_impl=dataset_impl,
|
||||
)
|
||||
|
||||
async def unregister_dataset(self, dataset_id: str) -> None:
|
||||
del self.dataset_infos[dataset_id]
|
||||
|
||||
async def get_rows_paginated(
|
||||
self,
|
||||
dataset_id: str,
|
||||
|
|
@ -128,3 +134,41 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
|
|||
total_count=len(rows),
|
||||
next_page_token=str(end),
|
||||
)
|
||||
|
||||
async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
|
||||
dataset_info = self.dataset_infos.get(dataset_id)
|
||||
if dataset_info is None:
|
||||
raise ValueError(f"Dataset with id {dataset_id} not found")
|
||||
|
||||
dataset_impl = dataset_info.dataset_impl
|
||||
dataset_impl.load()
|
||||
|
||||
new_rows_df = pandas.DataFrame(rows)
|
||||
new_rows_df = dataset_impl._validate_dataset_schema(new_rows_df)
|
||||
dataset_impl.df = pandas.concat(
|
||||
[dataset_impl.df, new_rows_df], ignore_index=True
|
||||
)
|
||||
|
||||
url = str(dataset_info.dataset_def.url)
|
||||
parsed_url = urlparse(url)
|
||||
|
||||
if parsed_url.scheme == "file" or not parsed_url.scheme:
|
||||
file_path = parsed_url.path
|
||||
os.makedirs(os.path.dirname(file_path), exist_ok=True)
|
||||
dataset_impl.df.to_csv(file_path, index=False)
|
||||
elif parsed_url.scheme == "data":
|
||||
# For data URLs, we need to update the base64-encoded content
|
||||
if not parsed_url.path.startswith("text/csv;base64,"):
|
||||
raise ValueError("Data URL must be a base64-encoded CSV")
|
||||
|
||||
csv_buffer = dataset_impl.df.to_csv(index=False)
|
||||
base64_content = base64.b64encode(csv_buffer.encode("utf-8")).decode(
|
||||
"utf-8"
|
||||
)
|
||||
dataset_info.dataset_def.url = URL(
|
||||
uri=f"data:text/csv;base64,{base64_content}"
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported URL scheme: {parsed_url.scheme}. Only file:// and data: URLs are supported for writing."
|
||||
)
|
||||
|
|
|
|||
5
llama_stack/providers/inline/eval/__init__.py
Normal file
5
llama_stack/providers/inline/eval/__init__.py
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
|
@ -3,12 +3,13 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR
|
||||
from llama_stack.providers.utils.kvstore.config import (
|
||||
KVStoreConfig,
|
||||
SqliteKVStoreConfig,
|
||||
)
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class MetaReferenceEvalConfig(BaseModel):
|
||||
|
|
|
|||
|
|
@ -4,7 +4,9 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from tqdm import tqdm
|
||||
|
||||
from .....apis.common.job_types import Job
|
||||
from .....apis.eval.eval import Eval, EvalTaskConfig, EvaluateResponse, JobStatus
|
||||
|
|
@ -17,7 +19,6 @@ from llama_stack.apis.inference import Inference
|
|||
from llama_stack.apis.scoring import Scoring
|
||||
from llama_stack.providers.datatypes import EvalTasksProtocolPrivate
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
from tqdm import tqdm
|
||||
|
||||
from .config import MetaReferenceEvalConfig
|
||||
|
||||
|
|
@ -72,7 +73,7 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
|
|||
key = f"{EVAL_TASKS_PREFIX}{task_def.identifier}"
|
||||
await self.kvstore.set(
|
||||
key=key,
|
||||
value=task_def.json(),
|
||||
value=task_def.model_dump_json(),
|
||||
)
|
||||
self.eval_tasks[task_def.identifier] = task_def
|
||||
|
||||
|
|
|
|||
|
|
@ -16,12 +16,14 @@ from llama_models.llama3.api.datatypes import * # noqa: F403
|
|||
from llama_stack.providers.utils.inference.model_registry import build_model_alias
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
from llama_stack.providers.datatypes import ModelsProtocolPrivate
|
||||
from llama_stack.providers.utils.inference.embedding_mixin import (
|
||||
SentenceTransformerEmbeddingMixin,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
convert_image_media_to_url,
|
||||
request_has_media,
|
||||
)
|
||||
|
||||
from .config import MetaReferenceInferenceConfig
|
||||
from .generation import Llama
|
||||
from .model_parallel import LlamaModelParallelGenerator
|
||||
|
|
@ -32,12 +34,17 @@ log = logging.getLogger(__name__)
|
|||
SEMAPHORE = asyncio.Semaphore(1)
|
||||
|
||||
|
||||
class MetaReferenceInferenceImpl(Inference, ModelRegistryHelper, ModelsProtocolPrivate):
|
||||
class MetaReferenceInferenceImpl(
|
||||
SentenceTransformerEmbeddingMixin,
|
||||
Inference,
|
||||
ModelsProtocolPrivate,
|
||||
):
|
||||
def __init__(self, config: MetaReferenceInferenceConfig) -> None:
|
||||
self.config = config
|
||||
model = resolve_model(config.model)
|
||||
ModelRegistryHelper.__init__(
|
||||
self,
|
||||
if model is None:
|
||||
raise RuntimeError(f"Unknown model: {config.model}, Run `llama model list`")
|
||||
self.model_registry_helper = ModelRegistryHelper(
|
||||
[
|
||||
build_model_alias(
|
||||
model.descriptor(),
|
||||
|
|
@ -45,8 +52,6 @@ class MetaReferenceInferenceImpl(Inference, ModelRegistryHelper, ModelsProtocolP
|
|||
)
|
||||
],
|
||||
)
|
||||
if model is None:
|
||||
raise RuntimeError(f"Unknown model: {config.model}, Run `llama model list`")
|
||||
self.model = model
|
||||
# verify that the checkpoint actually is for this model lol
|
||||
|
||||
|
|
@ -76,6 +81,12 @@ class MetaReferenceInferenceImpl(Inference, ModelRegistryHelper, ModelsProtocolP
|
|||
async def unregister_model(self, model_id: str) -> None:
|
||||
pass
|
||||
|
||||
async def register_model(self, model: Model) -> Model:
|
||||
model = await self.model_registry_helper.register_model(model)
|
||||
if model.model_type == ModelType.embedding_model:
|
||||
self._load_sentence_transformer_model(model.provider_resource_id)
|
||||
return model
|
||||
|
||||
async def completion(
|
||||
self,
|
||||
model_id: str,
|
||||
|
|
@ -394,13 +405,6 @@ class MetaReferenceInferenceImpl(Inference, ModelRegistryHelper, ModelsProtocolP
|
|||
for x in impl():
|
||||
yield x
|
||||
|
||||
async def embeddings(
|
||||
self,
|
||||
model_id: str,
|
||||
contents: List[InterleavedTextMedia],
|
||||
) -> EmbeddingsResponse:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
async def request_with_localized_media(
|
||||
request: Union[ChatCompletionRequest, CompletionRequest],
|
||||
|
|
|
|||
|
|
@ -0,0 +1,20 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.providers.inline.inference.sentence_transformers.config import (
|
||||
SentenceTransformersInferenceConfig,
|
||||
)
|
||||
|
||||
|
||||
async def get_provider_impl(
|
||||
config: SentenceTransformersInferenceConfig,
|
||||
_deps,
|
||||
):
|
||||
from .sentence_transformers import SentenceTransformersInferenceImpl
|
||||
|
||||
impl = SentenceTransformersInferenceImpl(config)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
@ -7,6 +7,4 @@
|
|||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class OpenTelemetryConfig(BaseModel):
|
||||
jaeger_host: str = "localhost"
|
||||
jaeger_port: int = 6831
|
||||
class SentenceTransformersInferenceConfig(BaseModel): ...
|
||||
|
|
@ -0,0 +1,74 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
from typing import AsyncGenerator, List, Optional, Union
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
CompletionResponse,
|
||||
Inference,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
ResponseFormat,
|
||||
SamplingParams,
|
||||
ToolChoice,
|
||||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
|
||||
from llama_stack.providers.utils.inference.embedding_mixin import (
|
||||
SentenceTransformerEmbeddingMixin,
|
||||
)
|
||||
from .config import SentenceTransformersInferenceConfig
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SentenceTransformersInferenceImpl(
|
||||
SentenceTransformerEmbeddingMixin,
|
||||
Inference,
|
||||
ModelsProtocolPrivate,
|
||||
):
|
||||
def __init__(self, config: SentenceTransformersInferenceConfig) -> None:
|
||||
self.config = config
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
async def register_model(self, model: Model) -> None:
|
||||
_ = self._load_sentence_transformer_model(model.provider_resource_id)
|
||||
return model
|
||||
|
||||
async def unregister_model(self, model_id: str) -> None:
|
||||
pass
|
||||
|
||||
async def completion(
|
||||
self,
|
||||
model_id: str,
|
||||
content: str,
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> Union[CompletionResponse, AsyncGenerator]:
|
||||
raise ValueError("Sentence transformers don't support completion")
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
model_id: str,
|
||||
messages: List[Message],
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
raise ValueError("Sentence transformers don't support chat completion")
|
||||
|
|
@ -4,12 +4,12 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .config import ConsoleConfig
|
||||
from .config import ChromaInlineImplConfig
|
||||
|
||||
|
||||
async def get_provider_impl(config: ConsoleConfig, _deps):
|
||||
from .console import ConsoleTelemetryImpl
|
||||
async def get_provider_impl(config: ChromaInlineImplConfig, _deps):
|
||||
from llama_stack.providers.remote.memory.chroma.chroma import ChromaMemoryAdapter
|
||||
|
||||
impl = ConsoleTelemetryImpl(config)
|
||||
impl = ChromaMemoryAdapter(config)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
@ -4,18 +4,14 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from enum import Enum
|
||||
|
||||
from llama_models.schema_utils import json_schema_type
|
||||
from typing import Any, Dict
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class LogFormat(Enum):
|
||||
TEXT = "text"
|
||||
JSON = "json"
|
||||
class ChromaInlineImplConfig(BaseModel):
|
||||
db_path: str
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ConsoleConfig(BaseModel):
|
||||
log_format: LogFormat = LogFormat.TEXT
|
||||
@classmethod
|
||||
def sample_config(cls) -> Dict[str, Any]:
|
||||
return {"db_path": "{env.CHROMADB_PATH}"}
|
||||
|
|
@ -4,16 +4,19 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Dict
|
||||
|
||||
from llama_stack.providers.datatypes import Api, ProviderSpec
|
||||
from .config import FaissImplConfig
|
||||
|
||||
|
||||
async def get_provider_impl(config: FaissImplConfig, _deps):
|
||||
async def get_provider_impl(config: FaissImplConfig, deps: Dict[Api, ProviderSpec]):
|
||||
from .faiss import FaissMemoryImpl
|
||||
|
||||
assert isinstance(
|
||||
config, FaissImplConfig
|
||||
), f"Unexpected config type: {type(config)}"
|
||||
|
||||
impl = FaissMemoryImpl(config)
|
||||
impl = FaissMemoryImpl(config, deps[Api.inference])
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
|||
|
|
@ -19,21 +19,20 @@ from numpy.typing import NDArray
|
|||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
|
||||
from llama_stack.apis.memory import * # noqa: F403
|
||||
from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate
|
||||
from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
|
||||
from llama_stack.providers.utils.memory.vector_store import (
|
||||
ALL_MINILM_L6_V2_DIMENSION,
|
||||
BankWithIndex,
|
||||
EmbeddingIndex,
|
||||
)
|
||||
from llama_stack.providers.utils.telemetry import tracing
|
||||
|
||||
from .config import FaissImplConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
MEMORY_BANKS_PREFIX = "memory_banks:v1::"
|
||||
MEMORY_BANKS_PREFIX = "memory_banks:v2::"
|
||||
FAISS_INDEX_PREFIX = "faiss_index:v2::"
|
||||
|
||||
|
||||
class FaissIndex(EmbeddingIndex):
|
||||
|
|
@ -57,7 +56,7 @@ class FaissIndex(EmbeddingIndex):
|
|||
if not self.kvstore:
|
||||
return
|
||||
|
||||
index_key = f"faiss_index:v1::{self.bank_id}"
|
||||
index_key = f"{FAISS_INDEX_PREFIX}{self.bank_id}"
|
||||
stored_data = await self.kvstore.get(index_key)
|
||||
|
||||
if stored_data:
|
||||
|
|
@ -80,21 +79,31 @@ class FaissIndex(EmbeddingIndex):
|
|||
np.savetxt(buffer, np_index)
|
||||
data = {
|
||||
"id_by_index": self.id_by_index,
|
||||
"chunk_by_index": {k: v.json() for k, v in self.chunk_by_index.items()},
|
||||
"chunk_by_index": {
|
||||
k: v.model_dump_json() for k, v in self.chunk_by_index.items()
|
||||
},
|
||||
"faiss_index": base64.b64encode(buffer.getvalue()).decode("utf-8"),
|
||||
}
|
||||
|
||||
index_key = f"faiss_index:v1::{self.bank_id}"
|
||||
index_key = f"{FAISS_INDEX_PREFIX}{self.bank_id}"
|
||||
await self.kvstore.set(key=index_key, value=json.dumps(data))
|
||||
|
||||
async def delete(self):
|
||||
if not self.kvstore or not self.bank_id:
|
||||
return
|
||||
|
||||
await self.kvstore.delete(f"faiss_index:v1::{self.bank_id}")
|
||||
await self.kvstore.delete(f"{FAISS_INDEX_PREFIX}{self.bank_id}")
|
||||
|
||||
@tracing.span(name="add_chunks")
|
||||
async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
|
||||
# Add dimension check
|
||||
embedding_dim = (
|
||||
embeddings.shape[1] if len(embeddings.shape) > 1 else embeddings.shape[0]
|
||||
)
|
||||
if embedding_dim != self.index.d:
|
||||
raise ValueError(
|
||||
f"Embedding dimension mismatch. Expected {self.index.d}, got {embedding_dim}"
|
||||
)
|
||||
|
||||
indexlen = len(self.id_by_index)
|
||||
for i, chunk in enumerate(chunks):
|
||||
self.chunk_by_index[indexlen + i] = chunk
|
||||
|
|
@ -124,8 +133,9 @@ class FaissIndex(EmbeddingIndex):
|
|||
|
||||
|
||||
class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate):
|
||||
def __init__(self, config: FaissImplConfig) -> None:
|
||||
def __init__(self, config: FaissImplConfig, inference_api: Api.inference) -> None:
|
||||
self.config = config
|
||||
self.inference_api = inference_api
|
||||
self.cache = {}
|
||||
self.kvstore = None
|
||||
|
||||
|
|
@ -139,10 +149,11 @@ class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate):
|
|||
for bank_data in stored_banks:
|
||||
bank = VectorMemoryBank.model_validate_json(bank_data)
|
||||
index = BankWithIndex(
|
||||
bank=bank,
|
||||
index=await FaissIndex.create(
|
||||
ALL_MINILM_L6_V2_DIMENSION, self.kvstore, bank.identifier
|
||||
bank,
|
||||
await FaissIndex.create(
|
||||
bank.embedding_dimension, self.kvstore, bank.identifier
|
||||
),
|
||||
self.inference_api,
|
||||
)
|
||||
self.cache[bank.identifier] = index
|
||||
|
||||
|
|
@ -162,17 +173,17 @@ class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate):
|
|||
key = f"{MEMORY_BANKS_PREFIX}{memory_bank.identifier}"
|
||||
await self.kvstore.set(
|
||||
key=key,
|
||||
value=memory_bank.json(),
|
||||
value=memory_bank.model_dump_json(),
|
||||
)
|
||||
|
||||
# Store in cache
|
||||
index = BankWithIndex(
|
||||
bank=memory_bank,
|
||||
index=await FaissIndex.create(
|
||||
ALL_MINILM_L6_V2_DIMENSION, self.kvstore, memory_bank.identifier
|
||||
self.cache[memory_bank.identifier] = BankWithIndex(
|
||||
memory_bank,
|
||||
await FaissIndex.create(
|
||||
memory_bank.embedding_dimension, self.kvstore, memory_bank.identifier
|
||||
),
|
||||
self.inference_api,
|
||||
)
|
||||
self.cache[memory_bank.identifier] = index
|
||||
|
||||
async def list_memory_banks(self) -> List[MemoryBank]:
|
||||
return [i.bank for i in self.cache.values()]
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
from typing import Optional
|
||||
from typing import List, Optional
|
||||
|
||||
from .config import LogFormat
|
||||
|
||||
|
|
@ -49,8 +49,27 @@ class ConsoleTelemetryImpl(Telemetry):
|
|||
if formatted:
|
||||
print(formatted)
|
||||
|
||||
async def get_trace(self, trace_id: str) -> Trace:
|
||||
raise NotImplementedError()
|
||||
async def query_traces(
|
||||
self,
|
||||
attribute_conditions: Optional[List[QueryCondition]] = None,
|
||||
attribute_keys_to_return: Optional[List[str]] = None,
|
||||
limit: Optional[int] = 100,
|
||||
offset: Optional[int] = 0,
|
||||
order_by: Optional[List[str]] = None,
|
||||
) -> List[Trace]:
|
||||
raise NotImplementedError("Console telemetry does not support trace querying")
|
||||
|
||||
async def get_spans(
|
||||
self,
|
||||
span_id: str,
|
||||
attribute_conditions: Optional[List[QueryCondition]] = None,
|
||||
attribute_keys_to_return: Optional[List[str]] = None,
|
||||
max_depth: Optional[int] = None,
|
||||
limit: Optional[int] = 100,
|
||||
offset: Optional[int] = 0,
|
||||
order_by: Optional[List[str]] = None,
|
||||
) -> SpanWithChildren:
|
||||
raise NotImplementedError("Console telemetry does not support span querying")
|
||||
|
||||
|
||||
COLORS = {
|
||||
|
|
|
|||
|
|
@ -22,5 +22,6 @@ async def get_provider_impl(
|
|||
impl = TorchtunePostTrainingImpl(
|
||||
config,
|
||||
deps[Api.datasetio],
|
||||
deps[Api.datasets],
|
||||
)
|
||||
return impl
|
||||
|
|
|
|||
|
|
@ -15,10 +15,14 @@ from llama_stack.providers.inline.post_training.torchtune.recipes.lora_finetunin
|
|||
|
||||
class TorchtunePostTrainingImpl:
|
||||
def __init__(
|
||||
self, config: TorchtunePostTrainingConfig, datasetio_api: DatasetIO
|
||||
self,
|
||||
config: TorchtunePostTrainingConfig,
|
||||
datasetio_api: DatasetIO,
|
||||
datasets: Datasets,
|
||||
) -> None:
|
||||
self.config = config
|
||||
self.datasetio_api = datasetio_api
|
||||
self.datasets_api = datasets
|
||||
|
||||
# TODO: assume sync job, will need jobs API for async scheduling
|
||||
self.jobs_status = {}
|
||||
|
|
|
|||
|
|
@ -76,6 +76,7 @@ class LoraFinetuningSingleDevice:
|
|||
checkpoint_dir: Optional[str],
|
||||
algorithm_config: Optional[Union[LoraFinetuningConfig, QATFinetuningConfig]],
|
||||
datasetio_api: DatasetIO,
|
||||
datasets_api: Datasets,
|
||||
) -> None:
|
||||
self.job_uuid = job_uuid
|
||||
self.training_config = training_config
|
||||
|
|
@ -106,7 +107,6 @@ class LoraFinetuningSingleDevice:
|
|||
model = resolve_model(self.model_id)
|
||||
self.checkpoint_dir = model_checkpoint_dir(model)
|
||||
|
||||
# TODO @markchen1015 make it work with get_training_job_artifacts
|
||||
self._output_dir = str(DEFAULT_CHECKPOINT_DIR)
|
||||
|
||||
self.seed = training.set_seed(seed=config.torch_seed)
|
||||
|
|
@ -230,7 +230,7 @@ class LoraFinetuningSingleDevice:
|
|||
self._use_dora = self.algorithm_config.use_dora or False
|
||||
|
||||
with training.set_default_dtype(self._dtype), self._device:
|
||||
model_type = utils.get_model_type(self.model_id)
|
||||
model_type = await utils.get_model_definition(self.model_id)
|
||||
model = model_type(
|
||||
lora_attn_modules=self._lora_attn_modules,
|
||||
apply_lora_to_mlp=self._apply_lora_to_mlp,
|
||||
|
|
@ -313,9 +313,11 @@ class LoraFinetuningSingleDevice:
|
|||
async def _setup_data(
|
||||
self, tokenizer: Llama3Tokenizer, shuffle: bool, batch_size: int
|
||||
) -> Tuple[DistributedSampler, DataLoader]:
|
||||
dataset_id = self.training_config.data_config.dataset_id
|
||||
|
||||
async def fetch_rows():
|
||||
return await self.datasetio_api.get_rows_paginated(
|
||||
dataset_id=self.training_config.data_config.dataset_id,
|
||||
dataset_id=dataset_id,
|
||||
rows_in_page=-1,
|
||||
)
|
||||
|
||||
|
|
@ -323,7 +325,13 @@ class LoraFinetuningSingleDevice:
|
|||
rows = all_rows.rows
|
||||
|
||||
# Curretly only support alpaca instruct dataset
|
||||
# TODO @markchen1015 make the message_transform swappable and support more dataset types
|
||||
# TODO @SLR722 make the message_transform swappable and support more dataset types
|
||||
# TODO @SLR722 make the input dataset schema more flexible by exposing column_map
|
||||
await utils.validate_input_dataset_schema(
|
||||
datasets_api=self.datasets_api,
|
||||
dataset_id=dataset_id,
|
||||
dataset_type="alpaca",
|
||||
)
|
||||
ds = SFTDataset(
|
||||
rows,
|
||||
message_transform=AlpacaToMessages(train_on_input=False),
|
||||
|
|
|
|||
139
llama_stack/providers/inline/post_training/torchtune/utils.py
Normal file
139
llama_stack/providers/inline/post_training/torchtune/utils.py
Normal file
|
|
@ -0,0 +1,139 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, IAny, nc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any, Callable, Dict, List
|
||||
|
||||
import torch
|
||||
from llama_stack.apis.datasets import Datasets
|
||||
from llama_stack.apis.common.type_system import * # noqa
|
||||
from llama_models.datatypes import Model
|
||||
from llama_models.sku_list import resolve_model
|
||||
from llama_stack.apis.common.type_system import ParamType
|
||||
|
||||
from torchtune.models.llama3 import llama3_tokenizer, lora_llama3_8b
|
||||
from torchtune.models.llama3._tokenizer import Llama3Tokenizer
|
||||
from torchtune.models.llama3_2 import lora_llama3_2_3b
|
||||
|
||||
|
||||
class ColumnName(Enum):
|
||||
instruction = "instruction"
|
||||
input = "input"
|
||||
output = "output"
|
||||
text = "text"
|
||||
|
||||
|
||||
class ModelConfig(BaseModel):
|
||||
model_definition: Any
|
||||
tokenizer_type: Any
|
||||
checkpoint_type: str
|
||||
|
||||
|
||||
class DatasetSchema(BaseModel):
|
||||
alpaca: List[Dict[str, ParamType]]
|
||||
|
||||
|
||||
MODEL_CONFIGS: Dict[str, ModelConfig] = {
|
||||
"Llama3.2-3B-Instruct": ModelConfig(
|
||||
model_definition=lora_llama3_2_3b,
|
||||
tokenizer_type=llama3_tokenizer,
|
||||
checkpoint_type="LLAMA3_2",
|
||||
),
|
||||
"Llama-3-8B-Instruct": ModelConfig(
|
||||
model_definition=lora_llama3_8b,
|
||||
tokenizer_type=llama3_tokenizer,
|
||||
checkpoint_type="LLAMA3",
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
EXPECTED_DATASET_SCHEMA = DatasetSchema(
|
||||
alpaca=[
|
||||
{
|
||||
ColumnName.instruction.value: StringType(),
|
||||
ColumnName.input.value: StringType(),
|
||||
ColumnName.output.value: StringType(),
|
||||
ColumnName.text.value: StringType(),
|
||||
},
|
||||
{
|
||||
ColumnName.instruction.value: StringType(),
|
||||
ColumnName.input.value: StringType(),
|
||||
ColumnName.output.value: StringType(),
|
||||
},
|
||||
{
|
||||
ColumnName.instruction.value: StringType(),
|
||||
ColumnName.output.value: StringType(),
|
||||
},
|
||||
]
|
||||
)
|
||||
|
||||
BuildLoraModelCallable = Callable[..., torch.nn.Module]
|
||||
BuildTokenizerCallable = Callable[..., Llama3Tokenizer]
|
||||
|
||||
|
||||
def _validate_model_id(model_id: str) -> Model:
|
||||
model = resolve_model(model_id)
|
||||
if model is None or model.core_model_id.value not in MODEL_CONFIGS:
|
||||
raise ValueError(f"Model {model_id} is not supported.")
|
||||
return model
|
||||
|
||||
|
||||
async def get_model_definition(
|
||||
model_id: str,
|
||||
) -> BuildLoraModelCallable:
|
||||
model = _validate_model_id(model_id)
|
||||
model_config = MODEL_CONFIGS[model.core_model_id.value]
|
||||
if not hasattr(model_config, "model_definition"):
|
||||
raise ValueError(f"Model {model_id} does not have model definition.")
|
||||
return model_config.model_definition
|
||||
|
||||
|
||||
async def get_tokenizer_type(
|
||||
model_id: str,
|
||||
) -> BuildTokenizerCallable:
|
||||
model = _validate_model_id(model_id)
|
||||
model_config = MODEL_CONFIGS[model.core_model_id.value]
|
||||
if not hasattr(model_config, "tokenizer_type"):
|
||||
raise ValueError(f"Model {model_id} does not have tokenizer_type.")
|
||||
return model_config.tokenizer_type
|
||||
|
||||
|
||||
async def get_checkpointer_model_type(
|
||||
model_id: str,
|
||||
) -> str:
|
||||
"""
|
||||
checkpointer model type is used in checkpointer for some special treatment on some specific model types
|
||||
For example, llama3.2 model tied weights (https://github.com/pytorch/torchtune/blob/main/torchtune/training/checkpointing/_checkpointer.py#L1041)
|
||||
"""
|
||||
model = _validate_model_id(model_id)
|
||||
model_config = MODEL_CONFIGS[model.core_model_id.value]
|
||||
if not hasattr(model_config, "checkpoint_type"):
|
||||
raise ValueError(f"Model {model_id} does not have checkpoint_type.")
|
||||
return model_config.checkpoint_type
|
||||
|
||||
|
||||
async def validate_input_dataset_schema(
|
||||
datasets_api: Datasets,
|
||||
dataset_id: str,
|
||||
dataset_type: str,
|
||||
) -> None:
|
||||
dataset_def = await datasets_api.get_dataset(dataset_id=dataset_id)
|
||||
if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
|
||||
raise ValueError(f"Dataset {dataset_id} does not have a schema defined.")
|
||||
|
||||
if not hasattr(EXPECTED_DATASET_SCHEMA, dataset_type):
|
||||
raise ValueError(f"Dataset type {dataset_type} is not supported.")
|
||||
|
||||
if dataset_def.dataset_schema not in getattr(EXPECTED_DATASET_SCHEMA, dataset_type):
|
||||
raise ValueError(
|
||||
f"Dataset {dataset_id} does not have a correct input schema in {getattr(EXPECTED_DATASET_SCHEMA, dataset_type)}"
|
||||
)
|
||||
5
llama_stack/providers/inline/scoring/__init__.py
Normal file
5
llama_stack/providers/inline/scoring/__init__.py
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
|
@ -113,7 +113,9 @@ class BasicScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
|
|||
score_results = await scoring_fn.score(
|
||||
input_rows, scoring_fn_id, scoring_fn_params
|
||||
)
|
||||
agg_results = await scoring_fn.aggregate(score_results)
|
||||
agg_results = await scoring_fn.aggregate(
|
||||
score_results, scoring_fn_id, scoring_fn_params
|
||||
)
|
||||
res[scoring_fn_id] = ScoringResult(
|
||||
score_rows=score_results,
|
||||
aggregated_results=agg_results,
|
||||
|
|
|
|||
|
|
@ -4,12 +4,12 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn
|
||||
from llama_stack.apis.scoring_functions import * # noqa: F401, F403
|
||||
from llama_stack.apis.scoring import * # noqa: F401, F403
|
||||
from llama_stack.apis.common.type_system import * # noqa: F403
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_accuracy
|
||||
from llama_stack.apis.scoring import ScoringResultRow
|
||||
|
||||
from llama_stack.apis.scoring_functions import ScoringFnParams
|
||||
from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn
|
||||
|
||||
from .fn_defs.equality import equality
|
||||
|
||||
|
|
@ -42,8 +42,3 @@ class EqualityScoringFn(BaseScoringFn):
|
|||
return {
|
||||
"score": score,
|
||||
}
|
||||
|
||||
async def aggregate(
|
||||
self, scoring_results: List[ScoringResultRow]
|
||||
) -> Dict[str, Any]:
|
||||
return aggregate_accuracy(scoring_results)
|
||||
|
|
|
|||
|
|
@ -5,14 +5,20 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.common.type_system import NumberType
|
||||
from llama_stack.apis.scoring_functions import ScoringFn
|
||||
from llama_stack.apis.scoring_functions import (
|
||||
AggregationFunctionType,
|
||||
BasicScoringFnParams,
|
||||
ScoringFn,
|
||||
)
|
||||
|
||||
|
||||
equality = ScoringFn(
|
||||
identifier="basic::equality",
|
||||
description="Returns 1.0 if the input is equal to the target, 0.0 otherwise.",
|
||||
params=None,
|
||||
provider_id="basic",
|
||||
provider_resource_id="equality",
|
||||
return_type=NumberType(),
|
||||
params=BasicScoringFnParams(
|
||||
aggregation_functions=[AggregationFunctionType.accuracy]
|
||||
),
|
||||
)
|
||||
|
|
|
|||
|
|
@ -4,9 +4,12 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.scoring_functions import * # noqa: F401, F403
|
||||
from llama_stack.apis.scoring import * # noqa: F401, F403
|
||||
from llama_stack.apis.common.type_system import NumberType
|
||||
from llama_stack.apis.scoring_functions import (
|
||||
AggregationFunctionType,
|
||||
RegexParserScoringFnParams,
|
||||
ScoringFn,
|
||||
)
|
||||
|
||||
MULTILINGUAL_ANSWER_REGEXES = [
|
||||
r"Answer\s*:",
|
||||
|
|
@ -67,5 +70,6 @@ regex_parser_multiple_choice_answer = ScoringFn(
|
|||
MULTILINGUAL_ANSWER_PATTERN_TEMPLATE.format(x)
|
||||
for x in MULTILINGUAL_ANSWER_REGEXES
|
||||
],
|
||||
aggregation_functions=[AggregationFunctionType.accuracy],
|
||||
),
|
||||
)
|
||||
|
|
|
|||
|
|
@ -5,7 +5,11 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.common.type_system import NumberType
|
||||
from llama_stack.apis.scoring_functions import ScoringFn
|
||||
from llama_stack.apis.scoring_functions import (
|
||||
AggregationFunctionType,
|
||||
BasicScoringFnParams,
|
||||
ScoringFn,
|
||||
)
|
||||
|
||||
|
||||
subset_of = ScoringFn(
|
||||
|
|
@ -14,4 +18,7 @@ subset_of = ScoringFn(
|
|||
return_type=NumberType(),
|
||||
provider_id="basic",
|
||||
provider_resource_id="subset-of",
|
||||
params=BasicScoringFnParams(
|
||||
aggregation_functions=[AggregationFunctionType.accuracy]
|
||||
),
|
||||
)
|
||||
|
|
|
|||
|
|
@ -5,11 +5,11 @@
|
|||
# the root directory of this source tree.
|
||||
import re
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from llama_stack.apis.scoring import ScoringResultRow
|
||||
from llama_stack.apis.scoring_functions import ScoringFnParams, ScoringFnParamsType
|
||||
from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn
|
||||
from llama_stack.apis.scoring_functions import * # noqa: F401, F403
|
||||
from llama_stack.apis.scoring import * # noqa: F401, F403
|
||||
from llama_stack.apis.common.type_system import * # noqa: F403
|
||||
from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_accuracy
|
||||
|
||||
from .fn_defs.regex_parser_multiple_choice_answer import (
|
||||
regex_parser_multiple_choice_answer,
|
||||
|
|
@ -60,8 +60,3 @@ class RegexParserScoringFn(BaseScoringFn):
|
|||
return {
|
||||
"score": score,
|
||||
}
|
||||
|
||||
async def aggregate(
|
||||
self, scoring_results: List[ScoringResultRow]
|
||||
) -> Dict[str, Any]:
|
||||
return aggregate_accuracy(scoring_results)
|
||||
|
|
|
|||
|
|
@ -4,11 +4,11 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from llama_stack.apis.scoring import ScoringResultRow
|
||||
from llama_stack.apis.scoring_functions import ScoringFnParams
|
||||
from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn
|
||||
from llama_stack.apis.scoring_functions import * # noqa: F401, F403
|
||||
from llama_stack.apis.scoring import * # noqa: F401, F403
|
||||
from llama_stack.apis.common.type_system import * # noqa: F403
|
||||
from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_accuracy
|
||||
|
||||
from .fn_defs.subset_of import subset_of
|
||||
|
||||
|
|
@ -36,8 +36,3 @@ class SubsetOfScoringFn(BaseScoringFn):
|
|||
return {
|
||||
"score": score,
|
||||
}
|
||||
|
||||
async def aggregate(
|
||||
self, scoring_results: List[ScoringResultRow]
|
||||
) -> Dict[str, Any]:
|
||||
return aggregate_accuracy(scoring_results)
|
||||
|
|
|
|||
|
|
@ -5,11 +5,17 @@
|
|||
# the root directory of this source tree.
|
||||
from typing import Dict
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.distribution.datatypes import Api, ProviderSpec
|
||||
|
||||
from .config import BraintrustScoringConfig
|
||||
|
||||
|
||||
class BraintrustProviderDataValidator(BaseModel):
|
||||
openai_api_key: str
|
||||
|
||||
|
||||
async def get_provider_impl(
|
||||
config: BraintrustScoringConfig,
|
||||
deps: Dict[Api, ProviderSpec],
|
||||
|
|
|
|||
|
|
@ -12,9 +12,12 @@ from llama_stack.apis.common.type_system import * # noqa: F403
|
|||
from llama_stack.apis.datasetio import * # noqa: F403
|
||||
from llama_stack.apis.datasets import * # noqa: F403
|
||||
|
||||
# from .scoring_fn.braintrust_scoring_fn import BraintrustScoringFn
|
||||
import os
|
||||
|
||||
from autoevals.llm import Factuality
|
||||
from autoevals.ragas import AnswerCorrectness
|
||||
|
||||
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
||||
from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
|
||||
|
||||
from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_average
|
||||
|
|
@ -24,7 +27,9 @@ from .scoring_fn.fn_defs.answer_correctness import answer_correctness_fn_def
|
|||
from .scoring_fn.fn_defs.factuality import factuality_fn_def
|
||||
|
||||
|
||||
class BraintrustScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
|
||||
class BraintrustScoringImpl(
|
||||
Scoring, ScoringFunctionsProtocolPrivate, NeedsRequestProviderData
|
||||
):
|
||||
def __init__(
|
||||
self,
|
||||
config: BraintrustScoringConfig,
|
||||
|
|
@ -79,12 +84,25 @@ class BraintrustScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
|
|||
f"Dataset {dataset_id} does not have a '{required_column}' column of type 'string'."
|
||||
)
|
||||
|
||||
async def set_api_key(self) -> None:
|
||||
# api key is in the request headers
|
||||
if not self.config.openai_api_key:
|
||||
provider_data = self.get_request_provider_data()
|
||||
if provider_data is None or not provider_data.openai_api_key:
|
||||
raise ValueError(
|
||||
'Pass OpenAI API Key in the header X-LlamaStack-ProviderData as { "openai_api_key": <your api key>}'
|
||||
)
|
||||
self.config.openai_api_key = provider_data.openai_api_key
|
||||
|
||||
os.environ["OPENAI_API_KEY"] = self.config.openai_api_key
|
||||
|
||||
async def score_batch(
|
||||
self,
|
||||
dataset_id: str,
|
||||
scoring_functions: List[str],
|
||||
save_results_dataset: bool = False,
|
||||
) -> ScoreBatchResponse:
|
||||
await self.set_api_key()
|
||||
await self.validate_scoring_input_dataset_schema(dataset_id=dataset_id)
|
||||
all_rows = await self.datasetio_api.get_rows_paginated(
|
||||
dataset_id=dataset_id,
|
||||
|
|
@ -105,6 +123,7 @@ class BraintrustScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
|
|||
async def score_row(
|
||||
self, input_row: Dict[str, Any], scoring_fn_identifier: Optional[str] = None
|
||||
) -> ScoringResultRow:
|
||||
await self.set_api_key()
|
||||
assert scoring_fn_identifier is not None, "scoring_fn_identifier cannot be None"
|
||||
expected_answer = input_row["expected_answer"]
|
||||
generated_answer = input_row["generated_answer"]
|
||||
|
|
@ -118,6 +137,7 @@ class BraintrustScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
|
|||
async def score(
|
||||
self, input_rows: List[Dict[str, Any]], scoring_functions: List[str]
|
||||
) -> ScoreResponse:
|
||||
await self.set_api_key()
|
||||
res = {}
|
||||
for scoring_fn_id in scoring_functions:
|
||||
if scoring_fn_id not in self.supported_fn_defs_registry:
|
||||
|
|
@ -127,7 +147,7 @@ class BraintrustScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
|
|||
await self.score_row(input_row, scoring_fn_id)
|
||||
for input_row in input_rows
|
||||
]
|
||||
|
||||
aggregation_functions = [AggregationFunctionType.average]
|
||||
agg_results = aggregate_average(score_results)
|
||||
res[scoring_fn_id] = ScoringResult(
|
||||
score_rows=score_results,
|
||||
|
|
|
|||
|
|
@ -6,4 +6,14 @@
|
|||
from llama_stack.apis.scoring import * # noqa: F401, F403
|
||||
|
||||
|
||||
class BraintrustScoringConfig(BaseModel): ...
|
||||
class BraintrustScoringConfig(BaseModel):
|
||||
openai_api_key: Optional[str] = Field(
|
||||
default=None,
|
||||
description="The OpenAI API Key",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
|
||||
return {
|
||||
"openai_api_key": "${env.OPENAI_API_KEY:}",
|
||||
}
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ from llama_stack.apis.scoring_functions import ScoringFn
|
|||
|
||||
answer_correctness_fn_def = ScoringFn(
|
||||
identifier="braintrust::answer-correctness",
|
||||
description="Test whether an output is factual, compared to an original (`expected`) value. One of Braintrust LLM basd scorer https://github.com/braintrustdata/autoevals/blob/main/py/autoevals/llm.py",
|
||||
description="Scores the correctness of the answer based on the ground truth.. One of Braintrust LLM basd scorer https://github.com/braintrustdata/autoevals/blob/main/py/autoevals/llm.py",
|
||||
params=None,
|
||||
provider_id="braintrust",
|
||||
provider_resource_id="answer-correctness",
|
||||
|
|
|
|||
|
|
@ -120,7 +120,9 @@ class LlmAsJudgeScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
|
|||
score_results = await scoring_fn.score(
|
||||
input_rows, scoring_fn_id, scoring_fn_params
|
||||
)
|
||||
agg_results = await scoring_fn.aggregate(score_results)
|
||||
agg_results = await scoring_fn.aggregate(
|
||||
score_results, scoring_fn_id, scoring_fn_params
|
||||
)
|
||||
res[scoring_fn_id] = ScoringResult(
|
||||
score_rows=score_results,
|
||||
aggregated_results=agg_results,
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.common.type_system import NumberType
|
||||
from llama_stack.apis.scoring_functions import ScoringFn
|
||||
from llama_stack.apis.scoring_functions import LLMAsJudgeScoringFnParams, ScoringFn
|
||||
|
||||
|
||||
llm_as_judge_base = ScoringFn(
|
||||
|
|
@ -14,4 +14,8 @@ llm_as_judge_base = ScoringFn(
|
|||
return_type=NumberType(),
|
||||
provider_id="llm-as-judge",
|
||||
provider_resource_id="llm-as-judge-base",
|
||||
params=LLMAsJudgeScoringFnParams(
|
||||
judge_model="meta-llama/Llama-3.1-405B-Instruct",
|
||||
prompt_template="Enter custom LLM as Judge Prompt Template",
|
||||
),
|
||||
)
|
||||
|
|
|
|||
|
|
@ -3,13 +3,16 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
import re
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from llama_stack.apis.inference.inference import Inference
|
||||
|
||||
from llama_stack.apis.scoring import ScoringResultRow
|
||||
from llama_stack.apis.scoring_functions import ScoringFnParams
|
||||
|
||||
from llama_stack.providers.utils.scoring.base_scoring_fn import BaseScoringFn
|
||||
from llama_stack.apis.scoring_functions import * # noqa: F401, F403
|
||||
from llama_stack.apis.scoring import * # noqa: F401, F403
|
||||
from llama_stack.apis.common.type_system import * # noqa: F403
|
||||
import re
|
||||
|
||||
from .fn_defs.llm_as_judge_405b_simpleqa import llm_as_judge_405b_simpleqa
|
||||
|
||||
|
|
@ -85,9 +88,3 @@ class LlmAsJudgeScoringFn(BaseScoringFn):
|
|||
"score": judge_rating,
|
||||
"judge_feedback": content,
|
||||
}
|
||||
|
||||
async def aggregate(
|
||||
self, scoring_results: List[ScoringResultRow]
|
||||
) -> Dict[str, Any]:
|
||||
# TODO: this needs to be config based aggregation, and only useful w/ Jobs API
|
||||
return {}
|
||||
|
|
|
|||
5
llama_stack/providers/inline/telemetry/__init__.py
Normal file
5
llama_stack/providers/inline/telemetry/__init__.py
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
|
@ -0,0 +1,19 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
from .config import TelemetryConfig, TelemetrySink
|
||||
|
||||
__all__ = ["TelemetryConfig", "TelemetrySink"]
|
||||
|
||||
|
||||
async def get_provider_impl(config: TelemetryConfig, deps: Dict[str, Any]):
|
||||
from .telemetry import TelemetryAdapter
|
||||
|
||||
impl = TelemetryAdapter(config, deps)
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
@ -0,0 +1,58 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR
|
||||
|
||||
|
||||
class TelemetrySink(str, Enum):
|
||||
OTEL = "otel"
|
||||
SQLITE = "sqlite"
|
||||
CONSOLE = "console"
|
||||
|
||||
|
||||
class TelemetryConfig(BaseModel):
|
||||
otel_endpoint: str = Field(
|
||||
default="http://localhost:4318/v1/traces",
|
||||
description="The OpenTelemetry collector endpoint URL",
|
||||
)
|
||||
service_name: str = Field(
|
||||
default="llama-stack",
|
||||
description="The service name to use for telemetry",
|
||||
)
|
||||
sinks: List[TelemetrySink] = Field(
|
||||
default=[TelemetrySink.CONSOLE, TelemetrySink.SQLITE],
|
||||
description="List of telemetry sinks to enable (possible values: otel, sqlite, console)",
|
||||
)
|
||||
sqlite_db_path: str = Field(
|
||||
default=(RUNTIME_BASE_DIR / "trace_store.db").as_posix(),
|
||||
description="The path to the SQLite database to use for storing traces",
|
||||
)
|
||||
|
||||
@field_validator("sinks", mode="before")
|
||||
@classmethod
|
||||
def validate_sinks(cls, v):
|
||||
if isinstance(v, str):
|
||||
return [TelemetrySink(sink.strip()) for sink in v.split(",")]
|
||||
return v
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(
|
||||
cls, __distro_dir__: str = "runtime", db_name: str = "trace_store.db"
|
||||
) -> Dict[str, Any]:
|
||||
return {
|
||||
"service_name": "${env.OTEL_SERVICE_NAME:llama-stack}",
|
||||
"sinks": "${env.TELEMETRY_SINKS:console,sqlite}",
|
||||
"sqlite_db_path": "${env.SQLITE_DB_PATH:~/.llama/"
|
||||
+ __distro_dir__
|
||||
+ "/"
|
||||
+ db_name
|
||||
+ "}",
|
||||
}
|
||||
|
|
@ -0,0 +1,117 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
from datetime import datetime
|
||||
|
||||
from opentelemetry.sdk.trace import ReadableSpan
|
||||
from opentelemetry.sdk.trace.export import SpanProcessor
|
||||
from opentelemetry.trace.status import StatusCode
|
||||
|
||||
# Colors for console output
|
||||
COLORS = {
|
||||
"reset": "\033[0m",
|
||||
"bold": "\033[1m",
|
||||
"dim": "\033[2m",
|
||||
"red": "\033[31m",
|
||||
"green": "\033[32m",
|
||||
"yellow": "\033[33m",
|
||||
"blue": "\033[34m",
|
||||
"magenta": "\033[35m",
|
||||
"cyan": "\033[36m",
|
||||
"white": "\033[37m",
|
||||
}
|
||||
|
||||
|
||||
class ConsoleSpanProcessor(SpanProcessor):
|
||||
|
||||
def __init__(self, print_attributes: bool = False):
|
||||
self.print_attributes = print_attributes
|
||||
|
||||
def on_start(self, span: ReadableSpan, parent_context=None) -> None:
|
||||
if span.attributes and span.attributes.get("__autotraced__"):
|
||||
return
|
||||
|
||||
timestamp = datetime.utcfromtimestamp(span.start_time / 1e9).strftime(
|
||||
"%H:%M:%S.%f"
|
||||
)[:-3]
|
||||
|
||||
print(
|
||||
f"{COLORS['dim']}{timestamp}{COLORS['reset']} "
|
||||
f"{COLORS['magenta']}[START]{COLORS['reset']} "
|
||||
f"{COLORS['dim']}{span.name}{COLORS['reset']}"
|
||||
)
|
||||
|
||||
def on_end(self, span: ReadableSpan) -> None:
|
||||
if span.attributes and span.attributes.get("__autotraced__"):
|
||||
return
|
||||
|
||||
timestamp = datetime.utcfromtimestamp(span.end_time / 1e9).strftime(
|
||||
"%H:%M:%S.%f"
|
||||
)[:-3]
|
||||
|
||||
span_context = (
|
||||
f"{COLORS['dim']}{timestamp}{COLORS['reset']} "
|
||||
f"{COLORS['magenta']}[END]{COLORS['reset']} "
|
||||
f"{COLORS['dim']}{span.name}{COLORS['reset']}"
|
||||
)
|
||||
|
||||
if span.status.status_code == StatusCode.ERROR:
|
||||
span_context += f"{COLORS['reset']} {COLORS['red']}[ERROR]{COLORS['reset']}"
|
||||
elif span.status.status_code != StatusCode.UNSET:
|
||||
span_context += f"{COLORS['reset']} [{span.status.status_code}]"
|
||||
|
||||
duration_ms = (span.end_time - span.start_time) / 1e6
|
||||
span_context += f"{COLORS['reset']} ({duration_ms:.2f}ms)"
|
||||
|
||||
print(span_context)
|
||||
|
||||
if self.print_attributes and span.attributes:
|
||||
for key, value in span.attributes.items():
|
||||
if key.startswith("__"):
|
||||
continue
|
||||
str_value = str(value)
|
||||
if len(str_value) > 1000:
|
||||
str_value = str_value[:997] + "..."
|
||||
print(f" {COLORS['dim']}{key}: {str_value}{COLORS['reset']}")
|
||||
|
||||
for event in span.events:
|
||||
event_time = datetime.utcfromtimestamp(event.timestamp / 1e9).strftime(
|
||||
"%H:%M:%S.%f"
|
||||
)[:-3]
|
||||
|
||||
severity = event.attributes.get("severity", "info")
|
||||
message = event.attributes.get("message", event.name)
|
||||
if isinstance(message, (dict, list)):
|
||||
message = json.dumps(message, indent=2)
|
||||
|
||||
severity_colors = {
|
||||
"error": f"{COLORS['bold']}{COLORS['red']}",
|
||||
"warn": f"{COLORS['bold']}{COLORS['yellow']}",
|
||||
"info": COLORS["white"],
|
||||
"debug": COLORS["dim"],
|
||||
}
|
||||
msg_color = severity_colors.get(severity, COLORS["white"])
|
||||
|
||||
print(
|
||||
f" {event_time} "
|
||||
f"{msg_color}[{severity.upper()}] "
|
||||
f"{message}{COLORS['reset']}"
|
||||
)
|
||||
|
||||
if event.attributes:
|
||||
for key, value in event.attributes.items():
|
||||
if key.startswith("__") or key in ["message", "severity"]:
|
||||
continue
|
||||
print(f" {COLORS['dim']}{key}: {value}{COLORS['reset']}")
|
||||
|
||||
def shutdown(self) -> None:
|
||||
"""Shutdown the processor."""
|
||||
pass
|
||||
|
||||
def force_flush(self, timeout_millis: float = None) -> bool:
|
||||
"""Force flush any pending spans."""
|
||||
return True
|
||||
|
|
@ -0,0 +1,177 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
import os
|
||||
import sqlite3
|
||||
from datetime import datetime
|
||||
|
||||
from opentelemetry.sdk.trace import SpanProcessor
|
||||
from opentelemetry.trace import Span
|
||||
|
||||
|
||||
class SQLiteSpanProcessor(SpanProcessor):
|
||||
def __init__(self, conn_string):
|
||||
"""Initialize the SQLite span processor with a connection string."""
|
||||
self.conn_string = conn_string
|
||||
self.conn = None
|
||||
self.setup_database()
|
||||
|
||||
def _get_connection(self) -> sqlite3.Connection:
|
||||
"""Get the database connection."""
|
||||
if self.conn is None:
|
||||
self.conn = sqlite3.connect(self.conn_string, check_same_thread=False)
|
||||
return self.conn
|
||||
|
||||
def setup_database(self):
|
||||
"""Create the necessary tables if they don't exist."""
|
||||
# Create directory if it doesn't exist
|
||||
os.makedirs(os.path.dirname(self.conn_string), exist_ok=True)
|
||||
|
||||
conn = self._get_connection()
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS traces (
|
||||
trace_id TEXT PRIMARY KEY,
|
||||
service_name TEXT,
|
||||
root_span_id TEXT,
|
||||
start_time TIMESTAMP,
|
||||
end_time TIMESTAMP,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
cursor.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS spans (
|
||||
span_id TEXT PRIMARY KEY,
|
||||
trace_id TEXT REFERENCES traces(trace_id),
|
||||
parent_span_id TEXT,
|
||||
name TEXT,
|
||||
start_time TIMESTAMP,
|
||||
end_time TIMESTAMP,
|
||||
attributes TEXT,
|
||||
status TEXT,
|
||||
kind TEXT
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
cursor.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS span_events (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
span_id TEXT REFERENCES spans(span_id),
|
||||
name TEXT,
|
||||
timestamp TIMESTAMP,
|
||||
attributes TEXT
|
||||
)
|
||||
"""
|
||||
)
|
||||
|
||||
cursor.execute(
|
||||
"""
|
||||
CREATE INDEX IF NOT EXISTS idx_traces_created_at
|
||||
ON traces(created_at)
|
||||
"""
|
||||
)
|
||||
|
||||
conn.commit()
|
||||
cursor.close()
|
||||
|
||||
def on_start(self, span: Span, parent_context=None):
|
||||
"""Called when a span starts."""
|
||||
pass
|
||||
|
||||
def on_end(self, span: Span):
|
||||
"""Called when a span ends. Export the span data to SQLite."""
|
||||
try:
|
||||
conn = self._get_connection()
|
||||
cursor = conn.cursor()
|
||||
|
||||
trace_id = format(span.get_span_context().trace_id, "032x")
|
||||
span_id = format(span.get_span_context().span_id, "016x")
|
||||
service_name = span.resource.attributes.get("service.name", "unknown")
|
||||
|
||||
parent_span_id = None
|
||||
parent_context = span.parent
|
||||
if parent_context:
|
||||
parent_span_id = format(parent_context.span_id, "016x")
|
||||
|
||||
# Insert into traces
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT INTO traces (
|
||||
trace_id, service_name, root_span_id, start_time, end_time
|
||||
) VALUES (?, ?, ?, ?, ?)
|
||||
ON CONFLICT(trace_id) DO UPDATE SET
|
||||
root_span_id = COALESCE(root_span_id, excluded.root_span_id),
|
||||
start_time = MIN(excluded.start_time, start_time),
|
||||
end_time = MAX(excluded.end_time, end_time)
|
||||
""",
|
||||
(
|
||||
trace_id,
|
||||
service_name,
|
||||
(span_id if not parent_span_id else None),
|
||||
datetime.fromtimestamp(span.start_time / 1e9).isoformat(),
|
||||
datetime.fromtimestamp(span.end_time / 1e9).isoformat(),
|
||||
),
|
||||
)
|
||||
|
||||
# Insert into spans
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT INTO spans (
|
||||
span_id, trace_id, parent_span_id, name,
|
||||
start_time, end_time, attributes, status,
|
||||
kind
|
||||
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
span_id,
|
||||
trace_id,
|
||||
parent_span_id,
|
||||
span.name,
|
||||
datetime.fromtimestamp(span.start_time / 1e9).isoformat(),
|
||||
datetime.fromtimestamp(span.end_time / 1e9).isoformat(),
|
||||
json.dumps(dict(span.attributes)),
|
||||
span.status.status_code.name,
|
||||
span.kind.name,
|
||||
),
|
||||
)
|
||||
|
||||
for event in span.events:
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT INTO span_events (
|
||||
span_id, name, timestamp, attributes
|
||||
) VALUES (?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
span_id,
|
||||
event.name,
|
||||
datetime.fromtimestamp(event.timestamp / 1e9).isoformat(),
|
||||
json.dumps(dict(event.attributes)),
|
||||
),
|
||||
)
|
||||
|
||||
conn.commit()
|
||||
cursor.close()
|
||||
except Exception as e:
|
||||
print(f"Error exporting span to SQLite: {e}")
|
||||
|
||||
def shutdown(self):
|
||||
"""Cleanup any resources."""
|
||||
if self.conn:
|
||||
self.conn.close()
|
||||
self.conn = None
|
||||
|
||||
def force_flush(self, timeout_millis=30000):
|
||||
"""Force export of spans."""
|
||||
pass
|
||||
|
|
@ -0,0 +1,251 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import threading
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from opentelemetry import metrics, trace
|
||||
from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter
|
||||
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
|
||||
from opentelemetry.sdk.metrics import MeterProvider
|
||||
from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader
|
||||
from opentelemetry.sdk.resources import Resource
|
||||
from opentelemetry.sdk.trace import TracerProvider
|
||||
from opentelemetry.sdk.trace.export import BatchSpanProcessor
|
||||
from opentelemetry.semconv.resource import ResourceAttributes
|
||||
|
||||
from llama_stack.providers.inline.telemetry.meta_reference.console_span_processor import (
|
||||
ConsoleSpanProcessor,
|
||||
)
|
||||
|
||||
from llama_stack.providers.inline.telemetry.meta_reference.sqlite_span_processor import (
|
||||
SQLiteSpanProcessor,
|
||||
)
|
||||
from llama_stack.providers.utils.telemetry.dataset_mixin import TelemetryDatasetMixin
|
||||
from llama_stack.providers.utils.telemetry.sqlite_trace_store import SQLiteTraceStore
|
||||
|
||||
from llama_stack.apis.telemetry import * # noqa: F403
|
||||
|
||||
from llama_stack.distribution.datatypes import Api
|
||||
|
||||
from .config import TelemetryConfig, TelemetrySink
|
||||
|
||||
_GLOBAL_STORAGE = {
|
||||
"active_spans": {},
|
||||
"counters": {},
|
||||
"gauges": {},
|
||||
"up_down_counters": {},
|
||||
}
|
||||
_global_lock = threading.Lock()
|
||||
|
||||
|
||||
def string_to_trace_id(s: str) -> int:
|
||||
# Convert the string to bytes and then to an integer
|
||||
return int.from_bytes(s.encode(), byteorder="big", signed=False)
|
||||
|
||||
|
||||
def string_to_span_id(s: str) -> int:
|
||||
# Use only the first 8 bytes (64 bits) for span ID
|
||||
return int.from_bytes(s.encode()[:8], byteorder="big", signed=False)
|
||||
|
||||
|
||||
def is_tracing_enabled(tracer):
|
||||
with tracer.start_as_current_span("check_tracing") as span:
|
||||
return span.is_recording()
|
||||
|
||||
|
||||
class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
|
||||
def __init__(self, config: TelemetryConfig, deps: Dict[str, Any]) -> None:
|
||||
self.config = config
|
||||
self.datasetio_api = deps[Api.datasetio]
|
||||
|
||||
resource = Resource.create(
|
||||
{
|
||||
ResourceAttributes.SERVICE_NAME: self.config.service_name,
|
||||
}
|
||||
)
|
||||
|
||||
provider = TracerProvider(resource=resource)
|
||||
trace.set_tracer_provider(provider)
|
||||
if TelemetrySink.OTEL in self.config.sinks:
|
||||
otlp_exporter = OTLPSpanExporter(
|
||||
endpoint=self.config.otel_endpoint,
|
||||
)
|
||||
span_processor = BatchSpanProcessor(otlp_exporter)
|
||||
trace.get_tracer_provider().add_span_processor(span_processor)
|
||||
metric_reader = PeriodicExportingMetricReader(
|
||||
OTLPMetricExporter(
|
||||
endpoint=self.config.otel_endpoint,
|
||||
)
|
||||
)
|
||||
metric_provider = MeterProvider(
|
||||
resource=resource, metric_readers=[metric_reader]
|
||||
)
|
||||
metrics.set_meter_provider(metric_provider)
|
||||
self.meter = metrics.get_meter(__name__)
|
||||
if TelemetrySink.SQLITE in self.config.sinks:
|
||||
trace.get_tracer_provider().add_span_processor(
|
||||
SQLiteSpanProcessor(self.config.sqlite_db_path)
|
||||
)
|
||||
self.trace_store = SQLiteTraceStore(self.config.sqlite_db_path)
|
||||
if TelemetrySink.CONSOLE in self.config.sinks:
|
||||
trace.get_tracer_provider().add_span_processor(ConsoleSpanProcessor())
|
||||
self._lock = _global_lock
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
trace.get_tracer_provider().force_flush()
|
||||
trace.get_tracer_provider().shutdown()
|
||||
metrics.get_meter_provider().shutdown()
|
||||
|
||||
async def log_event(self, event: Event, ttl_seconds: int = 604800) -> None:
|
||||
if isinstance(event, UnstructuredLogEvent):
|
||||
self._log_unstructured(event, ttl_seconds)
|
||||
elif isinstance(event, MetricEvent):
|
||||
self._log_metric(event)
|
||||
elif isinstance(event, StructuredLogEvent):
|
||||
self._log_structured(event, ttl_seconds)
|
||||
else:
|
||||
raise ValueError(f"Unknown event type: {event}")
|
||||
|
||||
def _log_unstructured(self, event: UnstructuredLogEvent, ttl_seconds: int) -> None:
|
||||
with self._lock:
|
||||
# Use global storage instead of instance storage
|
||||
span_id = string_to_span_id(event.span_id)
|
||||
span = _GLOBAL_STORAGE["active_spans"].get(span_id)
|
||||
|
||||
if span:
|
||||
timestamp_ns = int(event.timestamp.timestamp() * 1e9)
|
||||
span.add_event(
|
||||
name=event.type,
|
||||
attributes={
|
||||
"message": event.message,
|
||||
"severity": event.severity.value,
|
||||
"__ttl__": ttl_seconds,
|
||||
**event.attributes,
|
||||
},
|
||||
timestamp=timestamp_ns,
|
||||
)
|
||||
else:
|
||||
print(
|
||||
f"Warning: No active span found for span_id {span_id}. Dropping event: {event}"
|
||||
)
|
||||
|
||||
def _get_or_create_counter(self, name: str, unit: str) -> metrics.Counter:
|
||||
if name not in _GLOBAL_STORAGE["counters"]:
|
||||
_GLOBAL_STORAGE["counters"][name] = self.meter.create_counter(
|
||||
name=name,
|
||||
unit=unit,
|
||||
description=f"Counter for {name}",
|
||||
)
|
||||
return _GLOBAL_STORAGE["counters"][name]
|
||||
|
||||
def _get_or_create_gauge(self, name: str, unit: str) -> metrics.ObservableGauge:
|
||||
if name not in _GLOBAL_STORAGE["gauges"]:
|
||||
_GLOBAL_STORAGE["gauges"][name] = self.meter.create_gauge(
|
||||
name=name,
|
||||
unit=unit,
|
||||
description=f"Gauge for {name}",
|
||||
)
|
||||
return _GLOBAL_STORAGE["gauges"][name]
|
||||
|
||||
def _log_metric(self, event: MetricEvent) -> None:
|
||||
if isinstance(event.value, int):
|
||||
counter = self._get_or_create_counter(event.metric, event.unit)
|
||||
counter.add(event.value, attributes=event.attributes)
|
||||
elif isinstance(event.value, float):
|
||||
up_down_counter = self._get_or_create_up_down_counter(
|
||||
event.metric, event.unit
|
||||
)
|
||||
up_down_counter.add(event.value, attributes=event.attributes)
|
||||
|
||||
def _get_or_create_up_down_counter(
|
||||
self, name: str, unit: str
|
||||
) -> metrics.UpDownCounter:
|
||||
if name not in _GLOBAL_STORAGE["up_down_counters"]:
|
||||
_GLOBAL_STORAGE["up_down_counters"][name] = (
|
||||
self.meter.create_up_down_counter(
|
||||
name=name,
|
||||
unit=unit,
|
||||
description=f"UpDownCounter for {name}",
|
||||
)
|
||||
)
|
||||
return _GLOBAL_STORAGE["up_down_counters"][name]
|
||||
|
||||
def _log_structured(self, event: StructuredLogEvent, ttl_seconds: int) -> None:
|
||||
with self._lock:
|
||||
span_id = string_to_span_id(event.span_id)
|
||||
trace_id = string_to_trace_id(event.trace_id)
|
||||
tracer = trace.get_tracer(__name__)
|
||||
if event.attributes is None:
|
||||
event.attributes = {}
|
||||
event.attributes["__ttl__"] = ttl_seconds
|
||||
|
||||
if isinstance(event.payload, SpanStartPayload):
|
||||
# Check if span already exists to prevent duplicates
|
||||
if span_id in _GLOBAL_STORAGE["active_spans"]:
|
||||
return
|
||||
|
||||
parent_span = None
|
||||
if event.payload.parent_span_id:
|
||||
parent_span_id = string_to_span_id(event.payload.parent_span_id)
|
||||
parent_span = _GLOBAL_STORAGE["active_spans"].get(parent_span_id)
|
||||
|
||||
context = trace.Context(trace_id=trace_id)
|
||||
if parent_span:
|
||||
context = trace.set_span_in_context(parent_span, context)
|
||||
|
||||
span = tracer.start_span(
|
||||
name=event.payload.name,
|
||||
context=context,
|
||||
attributes=event.attributes or {},
|
||||
)
|
||||
_GLOBAL_STORAGE["active_spans"][span_id] = span
|
||||
|
||||
elif isinstance(event.payload, SpanEndPayload):
|
||||
span = _GLOBAL_STORAGE["active_spans"].get(span_id)
|
||||
if span:
|
||||
if event.attributes:
|
||||
span.set_attributes(event.attributes)
|
||||
|
||||
status = (
|
||||
trace.Status(status_code=trace.StatusCode.OK)
|
||||
if event.payload.status == SpanStatus.OK
|
||||
else trace.Status(status_code=trace.StatusCode.ERROR)
|
||||
)
|
||||
span.set_status(status)
|
||||
span.end()
|
||||
_GLOBAL_STORAGE["active_spans"].pop(span_id, None)
|
||||
else:
|
||||
raise ValueError(f"Unknown structured log event: {event}")
|
||||
|
||||
async def query_traces(
|
||||
self,
|
||||
attribute_filters: Optional[List[QueryCondition]] = None,
|
||||
limit: Optional[int] = 100,
|
||||
offset: Optional[int] = 0,
|
||||
order_by: Optional[List[str]] = None,
|
||||
) -> List[Trace]:
|
||||
return await self.trace_store.query_traces(
|
||||
attribute_filters=attribute_filters,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
order_by=order_by,
|
||||
)
|
||||
|
||||
async def get_span_tree(
|
||||
self,
|
||||
span_id: str,
|
||||
attributes_to_return: Optional[List[str]] = None,
|
||||
max_depth: Optional[int] = None,
|
||||
) -> SpanWithChildren:
|
||||
return await self.trace_store.get_span_tree(
|
||||
span_id=span_id,
|
||||
attributes_to_return=attributes_to_return,
|
||||
max_depth=max_depth,
|
||||
)
|
||||
|
|
@ -18,6 +18,7 @@ META_REFERENCE_DEPS = [
|
|||
"transformers",
|
||||
"zmq",
|
||||
"lm-format-enforcer",
|
||||
"sentence-transformers",
|
||||
]
|
||||
|
||||
|
||||
|
|
@ -52,6 +53,13 @@ def available_providers() -> List[ProviderSpec]:
|
|||
module="llama_stack.providers.inline.inference.vllm",
|
||||
config_class="llama_stack.providers.inline.inference.vllm.VLLMConfig",
|
||||
),
|
||||
InlineProviderSpec(
|
||||
api=Api.inference,
|
||||
provider_type="inline::sentence-transformers",
|
||||
pip_packages=["sentence-transformers"],
|
||||
module="llama_stack.providers.inline.inference.sentence_transformers",
|
||||
config_class="llama_stack.providers.inline.inference.sentence_transformers.config.SentenceTransformersInferenceConfig",
|
||||
),
|
||||
remote_provider_spec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
|
|
@ -61,6 +69,17 @@ def available_providers() -> List[ProviderSpec]:
|
|||
config_class="llama_stack.providers.remote.inference.sample.SampleConfig",
|
||||
),
|
||||
),
|
||||
remote_provider_spec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="cerebras",
|
||||
pip_packages=[
|
||||
"cerebras_cloud_sdk",
|
||||
],
|
||||
module="llama_stack.providers.remote.inference.cerebras",
|
||||
config_class="llama_stack.providers.remote.inference.cerebras.CerebrasImplConfig",
|
||||
),
|
||||
),
|
||||
remote_provider_spec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
|
|
@ -150,4 +169,15 @@ def available_providers() -> List[ProviderSpec]:
|
|||
config_class="llama_stack.providers.remote.inference.databricks.DatabricksImplConfig",
|
||||
),
|
||||
),
|
||||
remote_provider_spec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="nvidia",
|
||||
pip_packages=[
|
||||
"openai",
|
||||
],
|
||||
module="llama_stack.providers.remote.inference.nvidia",
|
||||
config_class="llama_stack.providers.remote.inference.nvidia.NVIDIAConfig",
|
||||
),
|
||||
),
|
||||
]
|
||||
|
|
|
|||
|
|
@ -39,6 +39,7 @@ def available_providers() -> List[ProviderSpec]:
|
|||
module="llama_stack.providers.inline.memory.faiss",
|
||||
config_class="llama_stack.providers.inline.memory.faiss.FaissImplConfig",
|
||||
deprecation_warning="Please use the `inline::faiss` provider instead.",
|
||||
api_dependencies=[Api.inference],
|
||||
),
|
||||
InlineProviderSpec(
|
||||
api=Api.memory,
|
||||
|
|
@ -46,6 +47,7 @@ def available_providers() -> List[ProviderSpec]:
|
|||
pip_packages=EMBEDDING_DEPS + ["faiss-cpu"],
|
||||
module="llama_stack.providers.inline.memory.faiss",
|
||||
config_class="llama_stack.providers.inline.memory.faiss.FaissImplConfig",
|
||||
api_dependencies=[Api.inference],
|
||||
),
|
||||
remote_provider_spec(
|
||||
Api.memory,
|
||||
|
|
@ -53,8 +55,16 @@ def available_providers() -> List[ProviderSpec]:
|
|||
adapter_type="chromadb",
|
||||
pip_packages=EMBEDDING_DEPS + ["chromadb-client"],
|
||||
module="llama_stack.providers.remote.memory.chroma",
|
||||
config_class="llama_stack.distribution.datatypes.RemoteProviderConfig",
|
||||
config_class="llama_stack.providers.remote.memory.chroma.ChromaRemoteImplConfig",
|
||||
),
|
||||
api_dependencies=[Api.inference],
|
||||
),
|
||||
InlineProviderSpec(
|
||||
api=Api.memory,
|
||||
provider_type="inline::chromadb",
|
||||
pip_packages=EMBEDDING_DEPS + ["chromadb"],
|
||||
module="llama_stack.providers.inline.memory.chroma",
|
||||
config_class="llama_stack.providers.inline.memory.chroma.ChromaInlineImplConfig",
|
||||
),
|
||||
remote_provider_spec(
|
||||
Api.memory,
|
||||
|
|
@ -64,6 +74,7 @@ def available_providers() -> List[ProviderSpec]:
|
|||
module="llama_stack.providers.remote.memory.pgvector",
|
||||
config_class="llama_stack.providers.remote.memory.pgvector.PGVectorConfig",
|
||||
),
|
||||
api_dependencies=[Api.inference],
|
||||
),
|
||||
remote_provider_spec(
|
||||
Api.memory,
|
||||
|
|
@ -74,6 +85,7 @@ def available_providers() -> List[ProviderSpec]:
|
|||
config_class="llama_stack.providers.remote.memory.weaviate.WeaviateConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.memory.weaviate.WeaviateRequestProviderData",
|
||||
),
|
||||
api_dependencies=[Api.inference],
|
||||
),
|
||||
remote_provider_spec(
|
||||
api=Api.memory,
|
||||
|
|
@ -83,6 +95,7 @@ def available_providers() -> List[ProviderSpec]:
|
|||
module="llama_stack.providers.remote.memory.sample",
|
||||
config_class="llama_stack.providers.remote.memory.sample.SampleConfig",
|
||||
),
|
||||
api_dependencies=[],
|
||||
),
|
||||
remote_provider_spec(
|
||||
Api.memory,
|
||||
|
|
@ -92,5 +105,6 @@ def available_providers() -> List[ProviderSpec]:
|
|||
module="llama_stack.providers.remote.memory.qdrant",
|
||||
config_class="llama_stack.providers.remote.memory.qdrant.QdrantConfig",
|
||||
),
|
||||
api_dependencies=[Api.inference],
|
||||
),
|
||||
]
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ def available_providers() -> List[ProviderSpec]:
|
|||
config_class="llama_stack.providers.inline.post_training.torchtune.TorchtunePostTrainingConfig",
|
||||
api_dependencies=[
|
||||
Api.datasetio,
|
||||
Api.datasets,
|
||||
],
|
||||
),
|
||||
]
|
||||
|
|
|
|||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue