Merge branch 'main' of https://github.com/santiagxf/llama-stack into santiagxf/azure-ai-inference

This commit is contained in:
Facundo Santiago 2024-11-08 15:04:48 +00:00
commit 75f742775d
98 changed files with 1131 additions and 586 deletions

View file

@ -22,6 +22,7 @@ pip install -r requirements.txt
pip install sphinx-autobuild
# This will start a local server (usually at http://127.0.0.1:8000) that automatically rebuilds and refreshes when you make changes to the documentation.
make html
sphinx-autobuild source build/html
```

View file

@ -1,15 +1,42 @@
# Remote-Hosted Distribution
Remote Hosted distributions are distributions connecting to remote hosted services through Llama Stack server. Inference is done through remote providers. These are useful if you have an API key for a remote inference provider like Fireworks, Together, etc.
Remote-Hosted distributions are available endpoints serving Llama Stack API that you can directly connect to.
| **Distribution** | **Llama Stack Docker** | Start This Distribution | **Inference** | **Agents** | **Memory** | **Safety** | **Telemetry** |
|:----------------: |:------------------------------------------: |:-----------------------: |:------------------: |:------------------: |:------------------: |:------------------: |:------------------: |
| Together | [llamastack/distribution-together](https://hub.docker.com/repository/docker/llamastack/distribution-together/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/getting_started/distributions/remote_hosted_distro/together.html) | remote::together | meta-reference | remote::weaviate | meta-reference | meta-reference |
| Fireworks | [llamastack/distribution-fireworks](https://hub.docker.com/repository/docker/llamastack/distribution-fireworks/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/getting_started/distributions/remote_hosted_distro/fireworks.html) | remote::fireworks | meta-reference | remote::weaviate | meta-reference | meta-reference |
| Distribution | Endpoint | Inference | Agents | Memory | Safety | Telemetry |
|-------------|----------|-----------|---------|---------|---------|------------|
| Together | [https://llama-stack.together.ai](https://llama-stack.together.ai) | remote::together | meta-reference | remote::weaviate | meta-reference | meta-reference |
| Fireworks | [https://llamastack-preview.fireworks.ai](https://llamastack-preview.fireworks.ai) | remote::fireworks | meta-reference | remote::weaviate | meta-reference | meta-reference |
```{toctree}
:maxdepth: 1
## Connecting to Remote-Hosted Distributions
fireworks
together
You can use `llama-stack-client` to interact with these endpoints. For example, to list the available models served by the Fireworks endpoint:
```bash
$ pip install llama-stack-client
$ llama-stack-client configure --endpoint https://llamastack-preview.fireworks.ai
$ llama-stack-client models list
```
You will see outputs:
```
$ llama-stack-client models list
+------------------------------+------------------------------+---------------+------------+
| identifier | llama_model | provider_id | metadata |
+==============================+==============================+===============+============+
| Llama3.1-8B-Instruct | Llama3.1-8B-Instruct | fireworks0 | {} |
+------------------------------+------------------------------+---------------+------------+
| Llama3.1-70B-Instruct | Llama3.1-70B-Instruct | fireworks0 | {} |
+------------------------------+------------------------------+---------------+------------+
| Llama3.1-405B-Instruct | Llama3.1-405B-Instruct | fireworks0 | {} |
+------------------------------+------------------------------+---------------+------------+
| Llama3.2-1B-Instruct | Llama3.2-1B-Instruct | fireworks0 | {} |
+------------------------------+------------------------------+---------------+------------+
| Llama3.2-3B-Instruct | Llama3.2-3B-Instruct | fireworks0 | {} |
+------------------------------+------------------------------+---------------+------------+
| Llama3.2-11B-Vision-Instruct | Llama3.2-11B-Vision-Instruct | fireworks0 | {} |
+------------------------------+------------------------------+---------------+------------+
| Llama3.2-90B-Vision-Instruct | Llama3.2-90B-Vision-Instruct | fireworks0 | {} |
+------------------------------+------------------------------+---------------+------------+
```
Checkout the [llama-stack-client-python](https://github.com/meta-llama/llama-stack-client-python/blob/main/docs/cli_reference.md) repo for more details on how to use the `llama-stack-client` CLI. Checkout [llama-stack-app](https://github.com/meta-llama/llama-stack-apps/tree/main) for examples applications built on top of Llama Stack.

View file

@ -8,6 +8,10 @@ We offer deployable distributions where you can host your own Llama Stack server
| Meta Reference Quantized | [llamastack/distribution-meta-reference-quantized-gpu](https://hub.docker.com/repository/docker/llamastack/distribution-meta-reference-quantized-gpu/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/getting_started/distributions/self_hosted_distro/meta-reference-quantized-gpu.html) | meta-reference-quantized | meta-reference | meta-reference; remote::pgvector; remote::chromadb | meta-reference | meta-reference |
| Ollama | [llamastack/distribution-ollama](https://hub.docker.com/repository/docker/llamastack/distribution-ollama/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/getting_started/distributions/self_hosted_distro/ollama.html) | remote::ollama | meta-reference | remote::pgvector; remote::chromadb | meta-reference | meta-reference |
| TGI | [llamastack/distribution-tgi](https://hub.docker.com/repository/docker/llamastack/distribution-tgi/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/getting_started/distributions/self_hosted_distro/tgi.html) | remote::tgi | meta-reference | meta-reference; remote::pgvector; remote::chromadb | meta-reference | meta-reference |
| Together | [llamastack/distribution-together](https://hub.docker.com/repository/docker/llamastack/distribution-together/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/getting_started/distributions/self_hosted_distro/together.html) | remote::together | meta-reference | remote::weaviate | meta-reference | meta-reference |
| Fireworks | [llamastack/distribution-fireworks](https://hub.docker.com/repository/docker/llamastack/distribution-fireworks/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/getting_started/distributions/self_hosted_distro/fireworks.html) | remote::fireworks | meta-reference | remote::weaviate | meta-reference | meta-reference |
| Bedrock | [llamastack/distribution-bedrock](https://hub.docker.com/repository/docker/llamastack/distribution-bedrock/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/getting_started/distributions/self_hosted_distro/bedrock.html) | remote::bedrock | meta-reference | remote::weaviate | meta-reference | meta-reference |
```{toctree}
:maxdepth: 1
@ -17,4 +21,7 @@ meta-reference-quantized-gpu
ollama
tgi
dell-tgi
together
fireworks
bedrock
```

View file

@ -14,6 +14,7 @@ from llama_stack.apis.scoring_functions import * # noqa: F403
from llama_stack.apis.agents import AgentConfig
from llama_stack.apis.common.job_types import Job, JobStatus
from llama_stack.apis.scoring import * # noqa: F403
from llama_stack.apis.eval_tasks import * # noqa: F403
@json_schema_type
@ -35,36 +36,57 @@ EvalCandidate = Annotated[
]
@json_schema_type
class BenchmarkEvalTaskConfig(BaseModel):
type: Literal["benchmark"] = "benchmark"
eval_candidate: EvalCandidate
@json_schema_type
class AppEvalTaskConfig(BaseModel):
type: Literal["app"] = "app"
eval_candidate: EvalCandidate
scoring_params: Dict[str, ScoringFnParams] = Field(
description="Map between scoring function id and parameters for each scoring function you want to run",
default_factory=dict,
)
# we could optinally add any specific dataset config here
EvalTaskConfig = Annotated[
Union[BenchmarkEvalTaskConfig, AppEvalTaskConfig], Field(discriminator="type")
]
@json_schema_type
class EvaluateResponse(BaseModel):
generations: List[Dict[str, Any]]
# each key in the dict is a scoring function name
scores: Dict[str, ScoringResult]
class Eval(Protocol):
@webmethod(route="/eval/evaluate_batch", method="POST")
async def evaluate_batch(
@webmethod(route="/eval/run_eval", method="POST")
async def run_eval(
self,
dataset_id: str,
candidate: EvalCandidate,
scoring_functions: List[str],
task_id: str,
task_config: EvalTaskConfig,
) -> Job: ...
@webmethod(route="/eval/evaluate", method="POST")
async def evaluate(
@webmethod(route="/eval/evaluate_rows", method="POST")
async def evaluate_rows(
self,
task_id: str,
input_rows: List[Dict[str, Any]],
candidate: EvalCandidate,
scoring_functions: List[str],
task_config: EvalTaskConfig,
) -> EvaluateResponse: ...
@webmethod(route="/eval/job/status", method="GET")
async def job_status(self, job_id: str) -> Optional[JobStatus]: ...
async def job_status(self, task_id: str, job_id: str) -> Optional[JobStatus]: ...
@webmethod(route="/eval/job/cancel", method="POST")
async def job_cancel(self, job_id: str) -> None: ...
async def job_cancel(self, task_id: str, job_id: str) -> None: ...
@webmethod(route="/eval/job/result", method="GET")
async def job_result(self, job_id: str) -> EvaluateResponse: ...
async def job_result(self, task_id: str, job_id: str) -> EvaluateResponse: ...

View file

@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .eval_tasks import * # noqa: F401 F403

View file

@ -0,0 +1,43 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field
@json_schema_type
class EvalTaskDef(BaseModel):
identifier: str
dataset_id: str
scoring_functions: List[str]
metadata: Dict[str, Any] = Field(
default_factory=dict,
description="Metadata for this evaluation task",
)
@json_schema_type
class EvalTaskDefWithProvider(EvalTaskDef):
type: Literal["eval_task"] = "eval_task"
provider_id: str = Field(
description="ID of the provider which serves this dataset",
)
@runtime_checkable
class EvalTasks(Protocol):
@webmethod(route="/eval_tasks/list", method="GET")
async def list_eval_tasks(self) -> List[EvalTaskDefWithProvider]: ...
@webmethod(route="/eval_tasks/get", method="GET")
async def get_eval_task(self, name: str) -> Optional[EvalTaskDefWithProvider]: ...
@webmethod(route="/eval_tasks/register", method="POST")
async def register_eval_task(
self, eval_task_def: EvalTaskDefWithProvider
) -> None: ...

View file

@ -48,11 +48,13 @@ class Scoring(Protocol):
async def score_batch(
self,
dataset_id: str,
scoring_functions: List[str],
scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
save_results_dataset: bool = False,
) -> ScoreBatchResponse: ...
@webmethod(route="/scoring/score")
async def score(
self, input_rows: List[Dict[str, Any]], scoring_functions: List[str]
self,
input_rows: List[Dict[str, Any]],
scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
) -> ScoreResponse: ...

View file

@ -4,34 +4,66 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable
from enum import Enum
from typing import (
Any,
Dict,
List,
Literal,
Optional,
Protocol,
runtime_checkable,
Union,
)
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field
from typing_extensions import Annotated
from llama_stack.apis.common.type_system import ParamType
@json_schema_type
class Parameter(BaseModel):
name: str
type: ParamType
description: Optional[str] = None
# Perhaps more structure can be imposed on these functions. Maybe they could be associated
# with standard metrics so they can be rolled up?
@json_schema_type
class ScoringConfigType(Enum):
llm_as_judge = "llm_as_judge"
regex_parser = "regex_parser"
class LLMAsJudgeContext(BaseModel):
@json_schema_type
class LLMAsJudgeScoringFnParams(BaseModel):
type: Literal[ScoringConfigType.llm_as_judge.value] = (
ScoringConfigType.llm_as_judge.value
)
judge_model: str
prompt_template: Optional[str] = None
judge_score_regex: Optional[List[str]] = Field(
description="Regex to extract the score from the judge response",
default=None,
judge_score_regexes: Optional[List[str]] = Field(
description="Regexes to extract the answer from generated response",
default_factory=list,
)
@json_schema_type
class RegexParserScoringFnParams(BaseModel):
type: Literal[ScoringConfigType.regex_parser.value] = (
ScoringConfigType.regex_parser.value
)
parsing_regexes: Optional[List[str]] = Field(
description="Regex to extract the answer from generated response",
default_factory=list,
)
ScoringFnParams = Annotated[
Union[
LLMAsJudgeScoringFnParams,
RegexParserScoringFnParams,
],
Field(discriminator="type"),
]
@json_schema_type
class ScoringFnDef(BaseModel):
identifier: str
@ -40,14 +72,13 @@ class ScoringFnDef(BaseModel):
default_factory=dict,
description="Any additional metadata for this definition",
)
parameters: List[Parameter] = Field(
description="List of parameters for the deterministic function",
default_factory=list,
)
return_type: ParamType = Field(
description="The return type of the deterministic function",
)
context: Optional[LLMAsJudgeContext] = None
params: Optional[ScoringFnParams] = Field(
description="The parameters for the scoring function for benchmark eval, these can be overridden for app eval",
default=None,
)
# We can optionally add information here to support packaging of code, etc.

View file

@ -43,6 +43,10 @@ def builtin_automatically_routed_apis() -> List[AutoRoutedApiInfo]:
routing_table_api=Api.scoring_functions,
router_api=Api.scoring,
),
AutoRoutedApiInfo(
routing_table_api=Api.eval_tasks,
router_api=Api.eval,
),
]

View file

@ -8,6 +8,8 @@ import inspect
from typing import Any, Dict, List, Set
from termcolor import cprint
from llama_stack.providers.datatypes import * # noqa: F403
from llama_stack.distribution.datatypes import * # noqa: F403
@ -15,6 +17,7 @@ from llama_stack.apis.agents import Agents
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.eval import Eval
from llama_stack.apis.eval_tasks import EvalTasks
from llama_stack.apis.inference import Inference
from llama_stack.apis.inspect import Inspect
from llama_stack.apis.memory import Memory
@ -46,6 +49,7 @@ def api_protocol_map() -> Dict[Api, Any]:
Api.scoring: Scoring,
Api.scoring_functions: ScoringFunctions,
Api.eval: Eval,
Api.eval_tasks: EvalTasks,
}
@ -56,6 +60,7 @@ def additional_protocols_map() -> Dict[Api, Any]:
Api.safety: (ShieldsProtocolPrivate, Shields),
Api.datasetio: (DatasetsProtocolPrivate, Datasets),
Api.scoring: (ScoringFunctionsProtocolPrivate, ScoringFunctions),
Api.eval_tasks: (EvalTasksProtocolPrivate, EvalTasks),
}
@ -97,6 +102,12 @@ async def resolve_impls(
)
p = provider_registry[api][provider.provider_type]
if p.deprecation_warning:
cprint(
f"Provider `{provider.provider_type}` for API `{api}` is deprecated and will be removed in a future release: {p.deprecation_warning}",
"red",
attrs=["bold"],
)
p.deps__ = [a.value for a in p.api_dependencies]
spec = ProviderWithSpec(
spec=p,

View file

@ -12,6 +12,7 @@ from llama_stack.distribution.store import DistributionRegistry
from .routing_tables import (
DatasetsRoutingTable,
EvalTasksRoutingTable,
MemoryBanksRoutingTable,
ModelsRoutingTable,
ScoringFunctionsRoutingTable,
@ -31,6 +32,7 @@ async def get_routing_table_impl(
"shields": ShieldsRoutingTable,
"datasets": DatasetsRoutingTable,
"scoring_functions": ScoringFunctionsRoutingTable,
"eval_tasks": EvalTasksRoutingTable,
}
if api.value not in api_to_tables:
@ -44,6 +46,7 @@ async def get_routing_table_impl(
async def get_auto_router_impl(api: Api, routing_table: RoutingTable, _deps) -> Any:
from .routers import (
DatasetIORouter,
EvalRouter,
InferenceRouter,
MemoryRouter,
SafetyRouter,
@ -56,6 +59,7 @@ async def get_auto_router_impl(api: Api, routing_table: RoutingTable, _deps) ->
"safety": SafetyRouter,
"datasetio": DatasetIORouter,
"scoring": ScoringRouter,
"eval": EvalRouter,
}
if api.value not in api_to_routers:
raise ValueError(f"API {api.value} not found in router map")

View file

@ -14,6 +14,7 @@ from llama_stack.apis.inference import * # noqa: F403
from llama_stack.apis.safety import * # noqa: F403
from llama_stack.apis.datasetio import * # noqa: F403
from llama_stack.apis.scoring import * # noqa: F403
from llama_stack.apis.eval import * # noqa: F403
class MemoryRouter(Memory):
@ -211,16 +212,16 @@ class ScoringRouter(Scoring):
async def score_batch(
self,
dataset_id: str,
scoring_functions: List[str],
scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
save_results_dataset: bool = False,
) -> ScoreBatchResponse:
res = {}
for fn_identifier in scoring_functions:
for fn_identifier in scoring_functions.keys():
score_response = await self.routing_table.get_provider_impl(
fn_identifier
).score_batch(
dataset_id=dataset_id,
scoring_functions=[fn_identifier],
scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
)
res.update(score_response.results)
@ -232,17 +233,87 @@ class ScoringRouter(Scoring):
)
async def score(
self, input_rows: List[Dict[str, Any]], scoring_functions: List[str]
self,
input_rows: List[Dict[str, Any]],
scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
) -> ScoreResponse:
res = {}
# look up and map each scoring function to its provider impl
for fn_identifier in scoring_functions:
for fn_identifier in scoring_functions.keys():
score_response = await self.routing_table.get_provider_impl(
fn_identifier
).score(
input_rows=input_rows,
scoring_functions=[fn_identifier],
scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
)
res.update(score_response.results)
return ScoreResponse(results=res)
class EvalRouter(Eval):
def __init__(
self,
routing_table: RoutingTable,
) -> None:
self.routing_table = routing_table
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
pass
async def run_eval(
self,
task_id: str,
task_config: AppEvalTaskConfig,
) -> Job:
return await self.routing_table.get_provider_impl(task_id).run_eval(
task_id=task_id,
task_config=task_config,
)
@webmethod(route="/eval/evaluate_rows", method="POST")
async def evaluate_rows(
self,
task_id: str,
input_rows: List[Dict[str, Any]],
scoring_functions: List[str],
task_config: EvalTaskConfig,
) -> EvaluateResponse:
return await self.routing_table.get_provider_impl(task_id).evaluate_rows(
task_id=task_id,
input_rows=input_rows,
scoring_functions=scoring_functions,
task_config=task_config,
)
async def job_status(
self,
task_id: str,
job_id: str,
) -> Optional[JobStatus]:
return await self.routing_table.get_provider_impl(task_id).job_status(
task_id, job_id
)
async def job_cancel(
self,
task_id: str,
job_id: str,
) -> None:
await self.routing_table.get_provider_impl(task_id).job_cancel(
task_id,
job_id,
)
async def job_result(
self,
task_id: str,
job_id: str,
) -> EvaluateResponse:
return await self.routing_table.get_provider_impl(task_id).job_result(
task_id,
job_id,
)

View file

@ -12,6 +12,8 @@ from llama_stack.apis.models import * # noqa: F403
from llama_stack.apis.shields import * # noqa: F403
from llama_stack.apis.memory_banks import * # noqa: F403
from llama_stack.apis.datasets import * # noqa: F403
from llama_stack.apis.eval_tasks import * # noqa: F403
from llama_stack.distribution.store import DistributionRegistry
from llama_stack.distribution.datatypes import * # noqa: F403
@ -40,6 +42,8 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> None:
await p.register_dataset(obj)
elif api == Api.scoring:
await p.register_scoring_function(obj)
elif api == Api.eval:
await p.register_eval_task(obj)
else:
raise ValueError(f"Unknown API {api} for registering object with provider")
@ -103,6 +107,11 @@ class CommonRoutingTableImpl(RoutingTable):
scoring_functions = await p.list_scoring_functions()
await add_objects(scoring_functions, pid, ScoringFnDefWithProvider)
elif api == Api.eval:
p.eval_task_store = self
eval_tasks = await p.list_eval_tasks()
await add_objects(eval_tasks, pid, EvalTaskDefWithProvider)
async def shutdown(self) -> None:
for p in self.impls_by_provider_id.values():
await p.shutdown()
@ -121,6 +130,8 @@ class CommonRoutingTableImpl(RoutingTable):
return ("DatasetIO", "dataset")
elif isinstance(self, ScoringFunctionsRoutingTable):
return ("Scoring", "scoring_function")
elif isinstance(self, EvalTasksRoutingTable):
return ("Eval", "eval_task")
else:
raise ValueError("Unknown routing table type")
@ -246,9 +257,9 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
await self.register_object(dataset_def)
class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, Scoring):
class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
async def list_scoring_functions(self) -> List[ScoringFnDefWithProvider]:
return await self.get_all_with_type("scoring_function")
return await self.get_all_with_type("scoring_fn")
async def get_scoring_function(
self, name: str
@ -259,3 +270,14 @@ class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, Scoring):
self, function_def: ScoringFnDefWithProvider
) -> None:
await self.register_object(function_def)
class EvalTasksRoutingTable(CommonRoutingTableImpl, EvalTasks):
async def list_eval_tasks(self) -> List[ScoringFnDefWithProvider]:
return await self.get_all_with_type("eval_task")
async def get_eval_task(self, name: str) -> Optional[EvalTaskDefWithProvider]:
return await self.get_object_by_identifier(name)
async def register_eval_task(self, eval_task_def: EvalTaskDefWithProvider) -> None:
await self.register_object(eval_task_def)

View file

@ -31,7 +31,7 @@ from llama_stack.distribution.distribution import (
get_provider_registry,
)
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
from llama_stack.distribution.store.registry import create_dist_registry
from llama_stack.providers.utils.telemetry.tracing import (
end_trace,
@ -42,8 +42,6 @@ from llama_stack.providers.utils.telemetry.tracing import (
from llama_stack.distribution.datatypes import * # noqa: F403
from llama_stack.distribution.request_headers import set_request_provider_data
from llama_stack.distribution.resolver import resolve_impls
from llama_stack.distribution.store import CachedDiskDistributionRegistry
from llama_stack.providers.utils.kvstore import kvstore_impl, SqliteKVStoreConfig
from .endpoints import get_all_api_endpoints
@ -281,21 +279,8 @@ def main(
config = StackRunConfig(**yaml.safe_load(fp))
app = FastAPI()
# instantiate kvstore for storing and retrieving distribution metadata
if config.metadata_store:
dist_kvstore = asyncio.run(kvstore_impl(config.metadata_store))
else:
dist_kvstore = asyncio.run(
kvstore_impl(
SqliteKVStoreConfig(
db_path=(
DISTRIBS_BASE_DIR / config.image_name / "kvstore.db"
).as_posix()
)
)
)
dist_registry = CachedDiskDistributionRegistry(dist_kvstore)
dist_registry, dist_kvstore = asyncio.run(create_dist_registry(config))
impls = asyncio.run(resolve_impls(config, get_provider_registry(), dist_registry))
if Api.telemetry in impls:

View file

@ -9,9 +9,17 @@ from typing import Dict, List, Protocol
import pydantic
from llama_stack.distribution.datatypes import RoutableObjectWithProvider
from llama_stack.distribution.datatypes import (
RoutableObjectWithProvider,
StackRunConfig,
)
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
from llama_stack.providers.utils.kvstore import KVStore
from llama_stack.providers.utils.kvstore import (
KVStore,
kvstore_impl,
SqliteKVStoreConfig,
)
class DistributionRegistry(Protocol):
@ -133,3 +141,21 @@ class CachedDiskDistributionRegistry(DiskDistributionRegistry):
self.cache[obj.identifier].append(obj)
return success
async def create_dist_registry(
config: StackRunConfig,
) -> tuple[CachedDiskDistributionRegistry, KVStore]:
# instantiate kvstore for storing and retrieving distribution metadata
if config.metadata_store:
dist_kvstore = await kvstore_impl(config.metadata_store)
else:
dist_kvstore = await kvstore_impl(
SqliteKVStoreConfig(
db_path=(
DISTRIBS_BASE_DIR / config.image_name / "kvstore.db"
).as_posix()
)
)
return CachedDiskDistributionRegistry(dist_kvstore), dist_kvstore

View file

@ -12,6 +12,7 @@ from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field
from llama_stack.apis.datasets import DatasetDef
from llama_stack.apis.eval_tasks import EvalTaskDef
from llama_stack.apis.memory_banks import MemoryBankDef
from llama_stack.apis.models import ModelDef
from llama_stack.apis.scoring_functions import ScoringFnDef
@ -35,6 +36,7 @@ class Api(Enum):
memory_banks = "memory_banks"
datasets = "datasets"
scoring_functions = "scoring_functions"
eval_tasks = "eval_tasks"
# built-in API
inspect = "inspect"
@ -70,6 +72,12 @@ class ScoringFunctionsProtocolPrivate(Protocol):
async def register_scoring_function(self, function_def: ScoringFnDef) -> None: ...
class EvalTasksProtocolPrivate(Protocol):
async def list_eval_tasks(self) -> List[EvalTaskDef]: ...
async def register_eval_task(self, eval_task_def: EvalTaskDef) -> None: ...
@json_schema_type
class ProviderSpec(BaseModel):
api: Api
@ -82,6 +90,10 @@ class ProviderSpec(BaseModel):
default_factory=list,
description="Higher-level API surfaces may depend on other providers to provide their functionality",
)
deprecation_warning: Optional[str] = Field(
default=None,
description="If this provider is deprecated, specify the warning message here",
)
# used internally by the resolver; this is a hack for now
deps__: List[str] = Field(default_factory=list)

View file

@ -4,10 +4,9 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pydantic import BaseModel, Field
from llama_stack.providers.utils.kvstore import KVStoreConfig
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
from pydantic import BaseModel, Field
class MetaReferenceAgentsImplConfig(BaseModel):

View file

@ -11,9 +11,8 @@ from datetime import datetime
from typing import List, Optional
from llama_stack.apis.agents import * # noqa: F403
from pydantic import BaseModel
from llama_stack.providers.utils.kvstore import KVStore
from pydantic import BaseModel
class AgentSessionInfo(BaseModel):

View file

@ -10,14 +10,13 @@ from jinja2 import Template
from llama_models.llama3.api import * # noqa: F403
from termcolor import cprint # noqa: F401
from llama_stack.apis.agents import (
DefaultMemoryQueryGeneratorConfig,
LLMMemoryQueryGeneratorConfig,
MemoryQueryGenerator,
MemoryQueryGeneratorConfig,
)
from termcolor import cprint # noqa: F401
from llama_stack.apis.inference import * # noqa: F403

View file

@ -9,8 +9,7 @@ from typing import List
from llama_stack.apis.inference import Message
from llama_stack.apis.safety import * # noqa: F403
from llama_stack.providers.inline.meta_reference.agents.safety import ShieldRunnerMixin
from ..safety import ShieldRunnerMixin
from .builtin import BaseTool

View file

@ -10,9 +10,8 @@ from llama_models.datatypes import * # noqa: F403
from llama_models.sku_list import resolve_model
from llama_stack.apis.inference import * # noqa: F401, F403
from pydantic import BaseModel, Field, field_validator
from llama_stack.providers.utils.inference import supported_inference_models
from pydantic import BaseModel, Field, field_validator
class MetaReferenceInferenceConfig(BaseModel):

View file

@ -35,13 +35,12 @@ from termcolor import cprint
from llama_stack.apis.inference import * # noqa: F403
from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerTokenizerData
from llama_stack.distribution.utils.model_utils import model_local_dir
from llama_stack.providers.utils.inference.prompt_adapter import (
augment_content_with_response_format_prompt,
chat_completion_request_to_messages,
)
from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerTokenizerData
from .config import (
Fp8QuantizationConfig,

View file

@ -28,13 +28,13 @@ from fairscale.nn.model_parallel.initialize import (
get_model_parallel_src_rank,
)
from llama_stack.apis.inference import ChatCompletionRequest, CompletionRequest
from pydantic import BaseModel, Field
from torch.distributed.launcher.api import elastic_launch, LaunchConfig
from typing_extensions import Annotated
from llama_stack.apis.inference import ChatCompletionRequest, CompletionRequest
from .generation import TokenResult

View file

@ -20,16 +20,15 @@ from llama_models.datatypes import CheckpointQuantizationFormat
from llama_models.llama3.api.args import ModelArgs
from llama_models.llama3.reference_impl.model import Transformer, TransformerBlock
from llama_models.sku_list import resolve_model
from llama_stack.apis.inference import QuantizationType
from termcolor import cprint
from torch import nn, Tensor
from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear
from llama_stack.apis.inference import QuantizationType
from llama_stack.providers.inline.meta_reference.inference.config import (
MetaReferenceQuantizedInferenceConfig,
)
from ..config import MetaReferenceQuantizedInferenceConfig
def swiglu_wrapper(

View file

@ -5,9 +5,9 @@
# the root directory of this source tree.
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field, field_validator
from llama_stack.providers.utils.inference import supported_inference_models
from pydantic import BaseModel, Field, field_validator
@json_schema_type

View file

@ -5,13 +5,13 @@
# the root directory of this source tree.
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel
from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR
from llama_stack.providers.utils.kvstore.config import (
KVStoreConfig,
SqliteKVStoreConfig,
)
from pydantic import BaseModel
@json_schema_type

View file

@ -8,10 +8,11 @@ import logging
from typing import Any, Dict, List, Optional
import faiss
import numpy as np
from numpy.typing import NDArray
import faiss
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.memory import * # noqa: F403

View file

@ -6,13 +6,15 @@
from enum import Enum
from llama_models.llama3.api.datatypes import * # noqa: F403
from .....apis.common.job_types import Job
from .....apis.eval.eval import Eval, EvalTaskConfig, EvaluateResponse, JobStatus
from llama_stack.apis.common.type_system import * # noqa: F403
from llama_stack.apis.common.job_types import Job
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.eval import Eval, EvalCandidate, EvaluateResponse, JobStatus
from llama_stack.apis.eval_tasks import EvalTaskDef
from llama_stack.apis.inference import Inference
from llama_stack.apis.scoring import Scoring
from llama_stack.providers.datatypes import EvalTasksProtocolPrivate
from .config import MetaReferenceEvalConfig
@ -25,7 +27,7 @@ class ColumnName(Enum):
generated_answer = "generated_answer"
class MetaReferenceEvalImpl(Eval):
class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
def __init__(
self,
config: MetaReferenceEvalConfig,
@ -43,10 +45,18 @@ class MetaReferenceEvalImpl(Eval):
# TODO: assume sync job, will need jobs API for async scheduling
self.jobs = {}
self.eval_tasks = {}
async def initialize(self) -> None: ...
async def shutdown(self) -> None: ...
async def register_eval_task(self, task_def: EvalTaskDef) -> None:
self.eval_tasks[task_def.identifier] = task_def
async def list_eval_tasks(self) -> List[EvalTaskDef]:
return list(self.eval_tasks.values())
async def validate_eval_input_dataset_schema(self, dataset_id: str) -> None:
dataset_def = await self.datasets_api.get_dataset(dataset_identifier=dataset_id)
if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
@ -70,21 +80,26 @@ class MetaReferenceEvalImpl(Eval):
f"Dataset {dataset_id} does not have a correct input schema in {expected_schemas}"
)
async def evaluate_batch(
async def run_eval(
self,
dataset_id: str,
candidate: EvalCandidate,
scoring_functions: List[str],
task_id: str,
task_config: EvalTaskConfig,
) -> Job:
task_def = self.eval_tasks[task_id]
dataset_id = task_def.dataset_id
candidate = task_config.eval_candidate
scoring_functions = task_def.scoring_functions
await self.validate_eval_input_dataset_schema(dataset_id=dataset_id)
all_rows = await self.datasetio_api.get_rows_paginated(
dataset_id=dataset_id,
rows_in_page=-1,
)
res = await self.evaluate(
res = await self.evaluate_rows(
task_id=task_id,
input_rows=all_rows.rows,
candidate=candidate,
scoring_functions=scoring_functions,
task_config=task_config,
)
# TODO: currently needs to wait for generation before returning
@ -93,12 +108,14 @@ class MetaReferenceEvalImpl(Eval):
self.jobs[job_id] = res
return Job(job_id=job_id)
async def evaluate(
async def evaluate_rows(
self,
task_id: str,
input_rows: List[Dict[str, Any]],
candidate: EvalCandidate,
scoring_functions: List[str],
task_config: EvalTaskConfig,
) -> EvaluateResponse:
candidate = task_config.eval_candidate
if candidate.type == "agent":
raise NotImplementedError(
"Evaluation with generation has not been implemented for agents"
@ -122,7 +139,10 @@ class MetaReferenceEvalImpl(Eval):
}
)
elif ColumnName.chat_completion_input.value in x:
input_messages = eval(str(x[ColumnName.chat_completion_input.value]))
chat_completion_input_str = str(
x[ColumnName.chat_completion_input.value]
)
input_messages = eval(chat_completion_input_str)
input_messages = [UserMessage(**x) for x in input_messages]
messages = []
if candidate.system_message:
@ -147,23 +167,33 @@ class MetaReferenceEvalImpl(Eval):
for input_r, generated_r in zip(input_rows, generations)
]
if task_config.type == "app" and task_config.scoring_params is not None:
scoring_functions_dict = {
scoring_fn_id: task_config.scoring_params.get(scoring_fn_id, None)
for scoring_fn_id in scoring_functions
}
else:
scoring_functions_dict = {
scoring_fn_id: None for scoring_fn_id in scoring_functions
}
score_response = await self.scoring_api.score(
input_rows=score_input_rows, scoring_functions=scoring_functions
input_rows=score_input_rows, scoring_functions=scoring_functions_dict
)
return EvaluateResponse(generations=generations, scores=score_response.results)
async def job_status(self, job_id: str) -> Optional[JobStatus]:
async def job_status(self, task_id: str, job_id: str) -> Optional[JobStatus]:
if job_id in self.jobs:
return JobStatus.completed
return None
async def job_cancel(self, job_id: str) -> None:
async def job_cancel(self, task_id: str, job_id: str) -> None:
raise NotImplementedError("Job cancel is not implemented yet")
async def job_result(self, job_id: str) -> EvaluateResponse:
status = await self.job_status(job_id)
async def job_result(self, task_id: str, job_id: str) -> EvaluateResponse:
status = await self.job_status(task_id, job_id)
if not status or status != JobStatus.completed:
raise ValueError(f"Job is not completed, Status: {status.value}")

View file

@ -1,73 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import tempfile
import pytest
from llama_stack.apis.memory import MemoryBankType, VectorMemoryBankDef
from llama_stack.providers.inline.meta_reference.memory.config import FaissImplConfig
from llama_stack.providers.inline.meta_reference.memory.faiss import FaissMemoryImpl
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
class TestFaissMemoryImpl:
@pytest.fixture
def faiss_impl(self):
# Create a temporary SQLite database file
temp_db = tempfile.NamedTemporaryFile(suffix=".db", delete=False)
config = FaissImplConfig(kvstore=SqliteKVStoreConfig(db_path=temp_db.name))
return FaissMemoryImpl(config)
@pytest.mark.asyncio
async def test_initialize(self, faiss_impl):
# Test empty initialization
await faiss_impl.initialize()
assert len(faiss_impl.cache) == 0
# Test initialization with existing banks
bank = VectorMemoryBankDef(
identifier="test_bank",
type=MemoryBankType.vector.value,
embedding_model="all-MiniLM-L6-v2",
chunk_size_in_tokens=512,
overlap_size_in_tokens=64,
)
# Register a bank and reinitialize to test loading
await faiss_impl.register_memory_bank(bank)
# Create new instance to test initialization with existing data
new_impl = FaissMemoryImpl(faiss_impl.config)
await new_impl.initialize()
assert len(new_impl.cache) == 1
assert "test_bank" in new_impl.cache
@pytest.mark.asyncio
async def test_register_memory_bank(self, faiss_impl):
bank = VectorMemoryBankDef(
identifier="test_bank",
type=MemoryBankType.vector.value,
embedding_model="all-MiniLM-L6-v2",
chunk_size_in_tokens=512,
overlap_size_in_tokens=64,
)
await faiss_impl.initialize()
await faiss_impl.register_memory_bank(bank)
assert "test_bank" in faiss_impl.cache
assert faiss_impl.cache["test_bank"].bank == bank
# Verify persistence
new_impl = FaissMemoryImpl(faiss_impl.config)
await new_impl.initialize()
assert "test_bank" in new_impl.cache
if __name__ == "__main__":
pytest.main([__file__])

View file

@ -74,8 +74,7 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
return scoring_fn_defs_list
async def register_scoring_function(self, function_def: ScoringFnDef) -> None:
self.llm_as_judge_fn.register_scoring_fn_def(function_def)
self.scoring_fn_id_impls[function_def.identifier] = self.llm_as_judge_fn
raise NotImplementedError("Register scoring function not implemented yet")
async def validate_scoring_input_dataset_schema(self, dataset_id: str) -> None:
dataset_def = await self.datasets_api.get_dataset(dataset_identifier=dataset_id)
@ -97,7 +96,7 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
async def score_batch(
self,
dataset_id: str,
scoring_functions: List[str],
scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
save_results_dataset: bool = False,
) -> ScoreBatchResponse:
await self.validate_scoring_input_dataset_schema(dataset_id=dataset_id)
@ -106,7 +105,8 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
rows_in_page=-1,
)
res = await self.score(
input_rows=all_rows.rows, scoring_functions=scoring_functions
input_rows=all_rows.rows,
scoring_functions=scoring_functions,
)
if save_results_dataset:
# TODO: persist and register dataset on to server for reading
@ -118,14 +118,19 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
)
async def score(
self, input_rows: List[Dict[str, Any]], scoring_functions: List[str]
self,
input_rows: List[Dict[str, Any]],
scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
) -> ScoreResponse:
res = {}
for scoring_fn_id in scoring_functions:
for scoring_fn_id in scoring_functions.keys():
if scoring_fn_id not in self.scoring_fn_id_impls:
raise ValueError(f"Scoring function {scoring_fn_id} is not supported.")
scoring_fn = self.scoring_fn_id_impls[scoring_fn_id]
score_results = await scoring_fn.score(input_rows, scoring_fn_id)
scoring_fn_params = scoring_functions.get(scoring_fn_id, None)
score_results = await scoring_fn.score(
input_rows, scoring_fn_id, scoring_fn_params
)
agg_results = await scoring_fn.aggregate(score_results)
res[scoring_fn_id] = ScoringResult(
score_rows=score_results,

View file

@ -36,7 +36,10 @@ class BaseScoringFn(ABC):
@abstractmethod
async def score_row(
self, input_row: Dict[str, Any], scoring_fn_identifier: Optional[str] = None
self,
input_row: Dict[str, Any],
scoring_fn_identifier: Optional[str] = None,
scoring_params: Optional[ScoringFnParams] = None,
) -> ScoringResultRow:
raise NotImplementedError()
@ -50,8 +53,9 @@ class BaseScoringFn(ABC):
self,
input_rows: List[Dict[str, Any]],
scoring_fn_identifier: Optional[str] = None,
scoring_params: Optional[ScoringFnParams] = None,
) -> List[ScoringResultRow]:
return [
await self.score_row(input_row, scoring_fn_identifier)
await self.score_row(input_row, scoring_fn_identifier, scoring_params)
for input_row in input_rows
]

View file

@ -35,6 +35,7 @@ class EqualityScoringFn(BaseScoringFn):
self,
input_row: Dict[str, Any],
scoring_fn_identifier: Optional[str] = "equality",
scoring_params: Optional[ScoringFnParams] = None,
) -> ScoringResultRow:
assert "expected_answer" in input_row, "Expected answer not found in input row."
assert (

View file

@ -28,9 +28,13 @@ llm_as_judge_8b_correctness = ScoringFnDef(
description="Llm As Judge Scoring Function",
parameters=[],
return_type=NumberType(),
context=LLMAsJudgeContext(
params=LLMAsJudgeScoringFnParams(
prompt_template=JUDGE_PROMPT,
judge_model="Llama3.1-8B-Instruct",
judge_score_regex=[r"Total rating: (\d+)", r"rating: (\d+)", r"Rating: (\d+)"],
judge_score_regexes=[
r"Total rating: (\d+)",
r"rating: (\d+)",
r"Rating: (\d+)",
],
),
)

View file

@ -36,31 +36,37 @@ class LlmAsJudgeScoringFn(BaseScoringFn):
self,
input_row: Dict[str, Any],
scoring_fn_identifier: Optional[str] = None,
scoring_params: Optional[ScoringFnParams] = None,
) -> ScoringResultRow:
assert (
scoring_fn_identifier is not None
), "Scoring function identifier not found."
fn_def = self.supported_fn_defs_registry[scoring_fn_identifier]
assert fn_def.context is not None, f"LLMAsJudgeContext not found for {fn_def}."
# override params if scoring_params is provided
if scoring_params is not None:
fn_def.params = scoring_params
assert fn_def.params is not None, f"LLMAsJudgeparams not found for {fn_def}."
assert (
fn_def.context.prompt_template is not None
fn_def.params.prompt_template is not None
), "LLM Judge prompt_template not found."
assert (
fn_def.context.judge_score_regex is not None
), "LLM Judge judge_score_regex not found."
fn_def.params.judge_score_regexes is not None
), "LLM Judge judge_score_regexes not found."
input_query = input_row["input_query"]
expected_answer = input_row["expected_answer"]
generated_answer = input_row["generated_answer"]
judge_input_msg = fn_def.context.prompt_template.format(
judge_input_msg = fn_def.params.prompt_template.format(
input_query=input_query,
expected_answer=expected_answer,
generated_answer=generated_answer,
)
judge_response = await self.inference_api.chat_completion(
model=fn_def.context.judge_model,
model=fn_def.params.judge_model,
messages=[
{
"role": "user",
@ -69,10 +75,10 @@ class LlmAsJudgeScoringFn(BaseScoringFn):
],
)
content = judge_response.completion_message.content
rating_regexs = fn_def.context.judge_score_regex
rating_regexes = fn_def.params.judge_score_regexes
judge_rating = None
for regex in rating_regexs:
for regex in rating_regexes:
match = re.search(regex, content)
if match:
judge_rating = int(match.group(1))

View file

@ -34,6 +34,7 @@ class SubsetOfScoringFn(BaseScoringFn):
self,
input_row: Dict[str, Any],
scoring_fn_identifier: Optional[str] = "subset_of",
scoring_params: Optional[ScoringFnParams] = None,
) -> ScoringResultRow:
expected_answer = input_row["expected_answer"]
generated_answer = input_row["generated_answer"]

View file

@ -22,8 +22,8 @@ def available_providers() -> List[ProviderSpec]:
"scikit-learn",
]
+ kvstore_dependencies(),
module="llama_stack.providers.inline.meta_reference.agents",
config_class="llama_stack.providers.inline.meta_reference.agents.MetaReferenceAgentsImplConfig",
module="llama_stack.providers.inline.agents.meta_reference",
config_class="llama_stack.providers.inline.agents.meta_reference.MetaReferenceAgentsImplConfig",
api_dependencies=[
Api.inference,
Api.safety,

View file

@ -27,8 +27,8 @@ def available_providers() -> List[ProviderSpec]:
api=Api.inference,
provider_type="meta-reference",
pip_packages=META_REFERENCE_DEPS,
module="llama_stack.providers.inline.meta_reference.inference",
config_class="llama_stack.providers.inline.meta_reference.inference.MetaReferenceInferenceConfig",
module="llama_stack.providers.inline.inference.meta_reference",
config_class="llama_stack.providers.inline.inference.meta_reference.MetaReferenceInferenceConfig",
),
InlineProviderSpec(
api=Api.inference,
@ -40,8 +40,17 @@ def available_providers() -> List[ProviderSpec]:
"torchao==0.5.0",
]
),
module="llama_stack.providers.inline.meta_reference.inference",
config_class="llama_stack.providers.inline.meta_reference.inference.MetaReferenceQuantizedInferenceConfig",
module="llama_stack.providers.inline.inference.meta_reference",
config_class="llama_stack.providers.inline.inference.meta_reference.MetaReferenceQuantizedInferenceConfig",
),
InlineProviderSpec(
api=Api.inference,
provider_type="vllm",
pip_packages=[
"vllm",
],
module="llama_stack.providers.inline.inference.vllm",
config_class="llama_stack.providers.inline.inference.vllm.VLLMConfig",
),
remote_provider_spec(
api=Api.inference,
@ -117,7 +126,7 @@ def available_providers() -> List[ProviderSpec]:
],
module="llama_stack.providers.remote.inference.together",
config_class="llama_stack.providers.remote.inference.together.TogetherImplConfig",
provider_data_validator="llama_stack.providers.remote.safety.together.TogetherProviderDataValidator",
provider_data_validator="llama_stack.providers.remote.inference.together.TogetherProviderDataValidator",
),
),
remote_provider_spec(
@ -149,13 +158,4 @@ def available_providers() -> List[ProviderSpec]:
config_class="llama_stack.providers.adapters.inference.azure_ai_inference.AzureAIInferenceConfig",
),
),
InlineProviderSpec(
api=Api.inference,
provider_type="vllm",
pip_packages=[
"vllm",
],
module="llama_stack.providers.inline.vllm",
config_class="llama_stack.providers.inline.vllm.VLLMConfig",
),
]

View file

@ -36,8 +36,16 @@ def available_providers() -> List[ProviderSpec]:
api=Api.memory,
provider_type="meta-reference",
pip_packages=EMBEDDING_DEPS + ["faiss-cpu"],
module="llama_stack.providers.inline.meta_reference.memory",
config_class="llama_stack.providers.inline.meta_reference.memory.FaissImplConfig",
module="llama_stack.providers.inline.memory.faiss",
config_class="llama_stack.providers.inline.memory.faiss.FaissImplConfig",
deprecation_warning="Please use the `faiss` provider instead.",
),
InlineProviderSpec(
api=Api.memory,
provider_type="faiss",
pip_packages=EMBEDDING_DEPS + ["faiss-cpu"],
module="llama_stack.providers.inline.memory.faiss",
config_class="llama_stack.providers.inline.memory.faiss.FaissImplConfig",
),
remote_provider_spec(
Api.memory,

View file

@ -24,8 +24,8 @@ def available_providers() -> List[ProviderSpec]:
"transformers",
"torch --index-url https://download.pytorch.org/whl/cpu",
],
module="llama_stack.providers.inline.meta_reference.safety",
config_class="llama_stack.providers.inline.meta_reference.safety.SafetyConfig",
module="llama_stack.providers.inline.safety.meta_reference",
config_class="llama_stack.providers.inline.safety.meta_reference.SafetyConfig",
api_dependencies=[
Api.inference,
],
@ -54,8 +54,8 @@ def available_providers() -> List[ProviderSpec]:
pip_packages=[
"codeshield",
],
module="llama_stack.providers.inline.meta_reference.codeshield",
config_class="llama_stack.providers.inline.meta_reference.codeshield.CodeShieldConfig",
module="llama_stack.providers.inline.safety.meta_reference",
config_class="llama_stack.providers.inline.safety.meta_reference.CodeShieldConfig",
api_dependencies=[],
),
]

View file

@ -4,6 +4,8 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Optional
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field
@ -14,7 +16,7 @@ class FireworksImplConfig(BaseModel):
default="https://api.fireworks.ai/inference",
description="The URL for the Fireworks server",
)
api_key: str = Field(
default="",
api_key: Optional[str] = Field(
default=None,
description="The Fireworks.ai API Key",
)

View file

@ -9,12 +9,11 @@ from typing import AsyncGenerator
from fireworks.client import Fireworks
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import Message
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.providers.utils.inference.openai_compat import (
get_sampling_options,
@ -32,7 +31,6 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
from .config import FireworksImplConfig
FIREWORKS_SUPPORTED_MODELS = {
"Llama3.1-8B-Instruct": "fireworks/llama-v3p1-8b-instruct",
"Llama3.1-70B-Instruct": "fireworks/llama-v3p1-70b-instruct",
@ -41,10 +39,13 @@ FIREWORKS_SUPPORTED_MODELS = {
"Llama3.2-3B-Instruct": "fireworks/llama-v3p2-3b-instruct",
"Llama3.2-11B-Vision-Instruct": "fireworks/llama-v3p2-11b-vision-instruct",
"Llama3.2-90B-Vision-Instruct": "fireworks/llama-v3p2-90b-vision-instruct",
"Llama-Guard-3-8B": "fireworks/llama-guard-3-8b",
}
class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
class FireworksInferenceAdapter(
ModelRegistryHelper, Inference, NeedsRequestProviderData
):
def __init__(self, config: FireworksImplConfig) -> None:
ModelRegistryHelper.__init__(
self, stack_to_provider_models_map=FIREWORKS_SUPPORTED_MODELS
@ -53,11 +54,24 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
self.formatter = ChatFormat(Tokenizer.get_instance())
async def initialize(self) -> None:
return
pass
async def shutdown(self) -> None:
pass
def _get_client(self) -> Fireworks:
fireworks_api_key = None
if self.config.api_key is not None:
fireworks_api_key = self.config.api_key
else:
provider_data = self.get_request_provider_data()
if provider_data is None or not provider_data.fireworks_api_key:
raise ValueError(
'Pass Fireworks API Key in the header X-LlamaStack-ProviderData as { "fireworks_api_key": <your api key>}'
)
fireworks_api_key = provider_data.fireworks_api_key
return Fireworks(api_key=fireworks_api_key)
async def completion(
self,
model: str,
@ -75,28 +89,53 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
stream=stream,
logprobs=logprobs,
)
client = Fireworks(api_key=self.config.api_key)
if stream:
return self._stream_completion(request, client)
return self._stream_completion(request)
else:
return await self._nonstream_completion(request, client)
return await self._nonstream_completion(request)
async def _nonstream_completion(
self, request: CompletionRequest, client: Fireworks
self, request: CompletionRequest
) -> CompletionResponse:
params = await self._get_params(request)
r = await client.completion.acreate(**params)
r = await self._get_client().completion.acreate(**params)
return process_completion_response(r, self.formatter)
async def _stream_completion(
self, request: CompletionRequest, client: Fireworks
) -> AsyncGenerator:
async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
params = await self._get_params(request)
stream = client.completion.acreate(**params)
# Wrapper for async generator similar
async def _to_async_generator():
stream = self._get_client().completion.create(**params)
for chunk in stream:
yield chunk
stream = _to_async_generator()
async for chunk in process_completion_stream_response(stream, self.formatter):
yield chunk
def _build_options(
self, sampling_params: Optional[SamplingParams], fmt: ResponseFormat
) -> dict:
options = get_sampling_options(sampling_params)
options.setdefault("max_tokens", 512)
if fmt:
if fmt.type == ResponseFormatType.json_schema.value:
options["response_format"] = {
"type": "json_object",
"schema": fmt.json_schema,
}
elif fmt.type == ResponseFormatType.grammar.value:
options["response_format"] = {
"type": "grammar",
"grammar": fmt.bnf,
}
else:
raise ValueError(f"Unknown response format {fmt.type}")
return options
async def chat_completion(
self,
model: str,
@ -121,32 +160,35 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
logprobs=logprobs,
)
client = Fireworks(api_key=self.config.api_key)
if stream:
return self._stream_chat_completion(request, client)
return self._stream_chat_completion(request)
else:
return await self._nonstream_chat_completion(request, client)
return await self._nonstream_chat_completion(request)
async def _nonstream_chat_completion(
self, request: ChatCompletionRequest, client: Fireworks
self, request: ChatCompletionRequest
) -> ChatCompletionResponse:
params = await self._get_params(request)
if "messages" in params:
r = await client.chat.completions.acreate(**params)
r = await self._get_client().chat.completions.acreate(**params)
else:
r = await client.completion.acreate(**params)
r = await self._get_client().completion.acreate(**params)
return process_chat_completion_response(r, self.formatter)
async def _stream_chat_completion(
self, request: ChatCompletionRequest, client: Fireworks
self, request: ChatCompletionRequest
) -> AsyncGenerator:
params = await self._get_params(request)
async def _to_async_generator():
if "messages" in params:
stream = client.chat.completions.acreate(**params)
stream = await self._get_client().chat.completions.acreate(**params)
else:
stream = client.completion.acreate(**params)
stream = self._get_client().completion.create(**params)
for chunk in stream:
yield chunk
stream = _to_async_generator()
async for chunk in process_chat_completion_stream_response(
stream, self.formatter
):
@ -167,41 +209,22 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
input_dict["prompt"] = chat_completion_request_to_prompt(
request, self.formatter
)
elif isinstance(request, CompletionRequest):
else:
assert (
not media_present
), "Fireworks does not support media for Completion requests"
input_dict["prompt"] = completion_request_to_prompt(request, self.formatter)
else:
raise ValueError(f"Unknown request type {type(request)}")
# Fireworks always prepends with BOS
if "prompt" in input_dict:
if input_dict["prompt"].startswith("<|begin_of_text|>"):
input_dict["prompt"] = input_dict["prompt"][len("<|begin_of_text|>") :]
options = get_sampling_options(request.sampling_params)
options.setdefault("max_tokens", 512)
if fmt := request.response_format:
if fmt.type == ResponseFormatType.json_schema.value:
options["response_format"] = {
"type": "json_object",
"schema": fmt.json_schema,
}
elif fmt.type == ResponseFormatType.grammar.value:
options["response_format"] = {
"type": "grammar",
"grammar": fmt.bnf,
}
else:
raise ValueError(f"Unknown response format {fmt.type}")
return {
"model": self.map_to_provider_model(request.model),
**input_dict,
"stream": request.stream,
**options,
**self._build_options(request.sampling_params, request.response_format),
}
async def embeddings(

View file

@ -4,9 +4,15 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pydantic import BaseModel
from .config import TogetherImplConfig
class TogetherProviderDataValidator(BaseModel):
together_api_key: str
async def get_adapter_impl(config: TogetherImplConfig, _deps):
from .together import TogetherInferenceAdapter

View file

@ -11,7 +11,7 @@ import pytest_asyncio
from llama_stack.distribution.datatypes import Api, Provider
from llama_stack.providers.inline.meta_reference.agents import (
from llama_stack.providers.inline.agents.meta_reference import (
MetaReferenceAgentsImplConfig,
)

View file

@ -153,4 +153,7 @@ pytest_plugins = [
"llama_stack.providers.tests.safety.fixtures",
"llama_stack.providers.tests.memory.fixtures",
"llama_stack.providers.tests.agents.fixtures",
"llama_stack.providers.tests.datasetio.fixtures",
"llama_stack.providers.tests.scoring.fixtures",
"llama_stack.providers.tests.eval.fixtures",
]

View file

@ -0,0 +1,29 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest
from .fixtures import DATASETIO_FIXTURES
def pytest_configure(config):
for fixture_name in DATASETIO_FIXTURES:
config.addinivalue_line(
"markers",
f"{fixture_name}: marks tests as {fixture_name} specific",
)
def pytest_generate_tests(metafunc):
if "datasetio_stack" in metafunc.fixturenames:
metafunc.parametrize(
"datasetio_stack",
[
pytest.param(fixture_name, marks=getattr(pytest.mark, fixture_name))
for fixture_name in DATASETIO_FIXTURES
],
indirect=True,
)

View file

@ -0,0 +1,48 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest
import pytest_asyncio
from llama_stack.distribution.datatypes import Api, Provider
from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
from ..conftest import ProviderFixture, remote_stack_fixture
@pytest.fixture(scope="session")
def datasetio_remote() -> ProviderFixture:
return remote_stack_fixture()
@pytest.fixture(scope="session")
def datasetio_meta_reference() -> ProviderFixture:
return ProviderFixture(
providers=[
Provider(
provider_id="meta-reference",
provider_type="meta-reference",
config={},
)
],
)
DATASETIO_FIXTURES = ["meta_reference", "remote"]
@pytest_asyncio.fixture(scope="session")
async def datasetio_stack(request):
fixture_name = request.param
fixture = request.getfixturevalue(f"datasetio_{fixture_name}")
impls = await resolve_impls_for_test_v2(
[Api.datasetio],
{"datasetio": fixture.providers},
fixture.provider_data,
)
return impls[Api.datasetio], impls[Api.datasets]

View file

@ -1,4 +0,0 @@
providers:
- provider_id: test-meta
provider_type: meta-reference
config: {}

View file

@ -3,11 +3,10 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
import pytest
import pytest_asyncio
from llama_stack.apis.common.type_system import * # noqa: F403
from llama_stack.apis.datasetio import * # noqa: F403
from llama_stack.distribution.datatypes import * # noqa: F403
@ -15,35 +14,11 @@ import base64
import mimetypes
from pathlib import Path
from llama_stack.providers.tests.resolver import resolve_impls_for_test
# How to run this test:
#
# 1. Ensure you have a conda with the right dependencies installed. This is a bit tricky
# since it depends on the provider you are testing. On top of that you need
# `pytest` and `pytest-asyncio` installed.
#
# 2. Copy and modify the provider_config_example.yaml depending on the provider you are testing.
#
# 3. Run:
#
# ```bash
# PROVIDER_ID=<your_provider> \
# PROVIDER_CONFIG=provider_config.yaml \
# pytest -s llama_stack/providers/tests/datasetio/test_datasetio.py \
# --tb=short --disable-warnings
# ```
@pytest_asyncio.fixture(scope="session")
async def datasetio_settings():
impls = await resolve_impls_for_test(
Api.datasetio,
)
return {
"datasetio_impl": impls[Api.datasetio],
"datasets_impl": impls[Api.datasets],
}
# pytest llama_stack/providers/tests/datasetio/test_datasetio.py
# -m "meta_reference"
# -v -s --tb=short --disable-warnings
def data_url_from_file(file_path: str) -> str:
@ -82,8 +57,7 @@ async def register_dataset(
dataset = DatasetDefWithProvider(
identifier=dataset_id,
provider_id=os.environ.get("DATASETIO_PROVIDER_ID", None)
or os.environ["PROVIDER_ID"],
provider_id="",
url=URL(
uri=test_url,
),
@ -92,57 +66,47 @@ async def register_dataset(
await datasets_impl.register_dataset(dataset)
class TestDatasetIO:
@pytest.mark.asyncio
async def test_datasets_list(datasetio_settings):
async def test_datasets_list(self, datasetio_stack):
# NOTE: this needs you to ensure that you are starting from a clean state
# but so far we don't have an unregister API unfortunately, so be careful
datasets_impl = datasetio_settings["datasets_impl"]
_, datasets_impl = datasetio_stack
response = await datasets_impl.list_datasets()
assert isinstance(response, list)
assert len(response) == 0
@pytest.mark.asyncio
async def test_datasets_register(datasetio_settings):
# NOTE: this needs you to ensure that you are starting from a clean state
# but so far we don't have an unregister API unfortunately, so be careful
datasets_impl = datasetio_settings["datasets_impl"]
await register_dataset(datasets_impl)
response = await datasets_impl.list_datasets()
assert isinstance(response, list)
assert len(response) == 1
# register same dataset with same id again will fail
async def test_register_dataset(self, datasetio_stack):
_, datasets_impl = datasetio_stack
await register_dataset(datasets_impl)
response = await datasets_impl.list_datasets()
assert isinstance(response, list)
assert len(response) == 1
assert response[0].identifier == "test_dataset"
@pytest.mark.asyncio
async def test_get_rows_paginated(datasetio_settings):
datasetio_impl = datasetio_settings["datasetio_impl"]
datasets_impl = datasetio_settings["datasets_impl"]
async def test_get_rows_paginated(self, datasetio_stack):
datasetio_impl, datasets_impl = datasetio_stack
await register_dataset(datasets_impl)
response = await datasetio_impl.get_rows_paginated(
dataset_id="test_dataset",
rows_in_page=3,
)
assert isinstance(response.rows, list)
assert len(response.rows) == 3
assert response.next_page_token == "3"
provider = datasetio_impl.routing_table.get_provider_impl("test_dataset")
if provider.__provider_spec__.provider_type == "remote":
pytest.skip("remote provider doesn't support get_rows_paginated")
# iterate over all rows
response = await datasetio_impl.get_rows_paginated(
dataset_id="test_dataset",
rows_in_page=2,
page_token=response.next_page_token,
)
assert isinstance(response.rows, list)
assert len(response.rows) == 2
assert response.next_page_token == "5"

View file

@ -0,0 +1,72 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest
from ..conftest import get_provider_fixture_overrides
from ..datasetio.fixtures import DATASETIO_FIXTURES
from ..inference.fixtures import INFERENCE_FIXTURES
from ..scoring.fixtures import SCORING_FIXTURES
from .fixtures import EVAL_FIXTURES
DEFAULT_PROVIDER_COMBINATIONS = [
pytest.param(
{
"eval": "meta_reference",
"scoring": "meta_reference",
"datasetio": "meta_reference",
"inference": "fireworks",
},
id="meta_reference_eval_fireworks_inference",
marks=pytest.mark.meta_reference_eval_fireworks_inference,
),
pytest.param(
{
"eval": "meta_reference",
"scoring": "meta_reference",
"datasetio": "meta_reference",
"inference": "together",
},
id="meta_reference_eval_together_inference",
marks=pytest.mark.meta_reference_eval_together_inference,
),
]
def pytest_configure(config):
for fixture_name in [
"meta_reference_eval_fireworks_inference",
"meta_reference_eval_together_inference",
]:
config.addinivalue_line(
"markers",
f"{fixture_name}: marks tests as {fixture_name} specific",
)
def pytest_addoption(parser):
parser.addoption(
"--inference-model",
action="store",
default="Llama3.2-3B-Instruct",
help="Specify the inference model to use for testing",
)
def pytest_generate_tests(metafunc):
if "eval_stack" in metafunc.fixturenames:
available_fixtures = {
"eval": EVAL_FIXTURES,
"scoring": SCORING_FIXTURES,
"datasetio": DATASETIO_FIXTURES,
"inference": INFERENCE_FIXTURES,
}
combinations = (
get_provider_fixture_overrides(metafunc.config, available_fixtures)
or DEFAULT_PROVIDER_COMBINATIONS
)
metafunc.parametrize("eval_stack", combinations, indirect=True)

View file

@ -0,0 +1,55 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest
import pytest_asyncio
from llama_stack.distribution.datatypes import Api, Provider
from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
from ..conftest import ProviderFixture, remote_stack_fixture
@pytest.fixture(scope="session")
def eval_remote() -> ProviderFixture:
return remote_stack_fixture()
@pytest.fixture(scope="session")
def eval_meta_reference() -> ProviderFixture:
return ProviderFixture(
providers=[
Provider(
provider_id="meta-reference",
provider_type="meta-reference",
config={},
)
],
)
EVAL_FIXTURES = ["meta_reference", "remote"]
@pytest_asyncio.fixture(scope="session")
async def eval_stack(request):
fixture_dict = request.param
providers = {}
provider_data = {}
for key in ["datasetio", "eval", "scoring", "inference"]:
fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}")
providers[key] = fixture.providers
if fixture.provider_data:
provider_data.update(fixture.provider_data)
impls = await resolve_impls_for_test_v2(
[Api.eval, Api.datasetio, Api.inference, Api.scoring],
providers,
provider_data,
)
return impls

View file

@ -1,22 +0,0 @@
providers:
datasetio:
- provider_id: test-meta
provider_type: meta-reference
config: {}
scoring:
- provider_id: test-meta
provider_type: meta-reference
config: {}
eval:
- provider_id: test-meta
provider_type: meta-reference
config: {}
inference:
- provider_id: test-tgi
provider_type: remote::tgi
config:
url: http://127.0.0.1:5009
- provider_id: test-tgi-2
provider_type: remote::tgi
config:
url: http://127.0.0.1:5010

View file

@ -3,79 +3,122 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest
import pytest_asyncio
from llama_stack.apis.common.type_system import * # noqa: F403
from llama_stack.apis.datasetio import * # noqa: F403
from llama_stack.apis.eval.eval import ModelCandidate
from llama_stack.distribution.datatypes import * # noqa: F403
import pytest
from llama_models.llama3.api import SamplingParams
from llama_stack.apis.eval.eval import (
AppEvalTaskConfig,
EvalTaskDefWithProvider,
ModelCandidate,
)
from llama_stack.distribution.datatypes import Api
from llama_stack.providers.tests.datasetio.test_datasetio import register_dataset
from llama_stack.providers.tests.resolver import resolve_impls_for_test
# How to run this test:
#
# 1. Ensure you have a conda with the right dependencies installed. This is a bit tricky
# since it depends on the provider you are testing. On top of that you need
# `pytest` and `pytest-asyncio` installed.
#
# 2. Copy and modify the provider_config_example.yaml depending on the provider you are testing.
#
# 3. Run:
#
# ```bash
# PROVIDER_ID=<your_provider> \
# PROVIDER_CONFIG=provider_config.yaml \
# pytest -s llama_stack/providers/tests/eval/test_eval.py \
# --tb=short --disable-warnings
# ```
# pytest llama_stack/providers/tests/eval/test_eval.py
# -m "meta_reference"
# -v -s --tb=short --disable-warnings
@pytest_asyncio.fixture(scope="session")
async def eval_settings():
impls = await resolve_impls_for_test(
Api.eval, deps=[Api.datasetio, Api.scoring, Api.inference]
)
return {
"eval_impl": impls[Api.eval],
"scoring_impl": impls[Api.scoring],
"datasets_impl": impls[Api.datasets],
}
class Testeval:
@pytest.mark.asyncio
async def test_eval_tasks_list(self, eval_stack):
# NOTE: this needs you to ensure that you are starting from a clean state
# but so far we don't have an unregister API unfortunately, so be careful
eval_tasks_impl = eval_stack[Api.eval_tasks]
response = await eval_tasks_impl.list_eval_tasks()
assert isinstance(response, list)
assert len(response) == 0
@pytest.mark.asyncio
async def test_eval(eval_settings):
datasets_impl = eval_settings["datasets_impl"]
await register_dataset(
datasets_impl,
for_generation=True,
dataset_id="test_dataset_for_eval",
async def test_eval_evaluate_rows(self, eval_stack):
eval_impl, eval_tasks_impl, datasetio_impl, datasets_impl = (
eval_stack[Api.eval],
eval_stack[Api.eval_tasks],
eval_stack[Api.datasetio],
eval_stack[Api.datasets],
)
await register_dataset(
datasets_impl, for_generation=True, dataset_id="test_dataset_for_eval"
)
response = await datasets_impl.list_datasets()
assert len(response) == 1
rows = await datasetio_impl.get_rows_paginated(
dataset_id="test_dataset_for_eval",
rows_in_page=3,
)
assert len(rows.rows) == 3
eval_impl = eval_settings["eval_impl"]
response = await eval_impl.evaluate_batch(
dataset_id=response[0].identifier,
candidate=ModelCandidate(
model="Llama3.2-1B-Instruct",
scoring_functions = [
"meta-reference::llm_as_judge_8b_correctness",
"meta-reference::equality",
]
task_id = "meta-reference::app_eval"
task_def = EvalTaskDefWithProvider(
identifier=task_id,
dataset_id="test_dataset_for_eval",
scoring_functions=scoring_functions,
provider_id="meta-reference",
)
await eval_tasks_impl.register_eval_task(task_def)
response = await eval_impl.evaluate_rows(
task_id=task_id,
input_rows=rows.rows,
scoring_functions=scoring_functions,
task_config=AppEvalTaskConfig(
eval_candidate=ModelCandidate(
model="Llama3.2-3B-Instruct",
sampling_params=SamplingParams(),
),
),
)
assert len(response.generations) == 3
assert "meta-reference::llm_as_judge_8b_correctness" in response.scores
assert "meta-reference::equality" in response.scores
@pytest.mark.asyncio
async def test_eval_run_eval(self, eval_stack):
eval_impl, eval_tasks_impl, datasets_impl = (
eval_stack[Api.eval],
eval_stack[Api.eval_tasks],
eval_stack[Api.datasets],
)
await register_dataset(
datasets_impl, for_generation=True, dataset_id="test_dataset_for_eval"
)
scoring_functions = [
"meta-reference::subset_of",
"meta-reference::llm_as_judge_8b_correctness",
],
"meta-reference::subset_of",
]
task_id = "meta-reference::app_eval-2"
task_def = EvalTaskDefWithProvider(
identifier=task_id,
dataset_id="test_dataset_for_eval",
scoring_functions=scoring_functions,
provider_id="meta-reference",
)
await eval_tasks_impl.register_eval_task(task_def)
response = await eval_impl.run_eval(
task_id=task_id,
task_config=AppEvalTaskConfig(
eval_candidate=ModelCandidate(
model="Llama3.2-3B-Instruct",
sampling_params=SamplingParams(),
),
),
)
assert response.job_id == "0"
job_status = await eval_impl.job_status(response.job_id)
job_status = await eval_impl.job_status(task_id, response.job_id)
assert job_status and job_status.value == "completed"
eval_response = await eval_impl.job_result(response.job_id)
eval_response = await eval_impl.job_result(task_id, response.job_id)
assert eval_response is not None
assert len(eval_response.generations) == 5

View file

@ -10,7 +10,7 @@ import pytest
import pytest_asyncio
from llama_stack.distribution.datatypes import Api, Provider
from llama_stack.providers.inline.meta_reference.inference import (
from llama_stack.providers.inline.inference.meta_reference import (
MetaReferenceInferenceConfig,
)
@ -64,6 +64,7 @@ def inference_ollama(inference_model) -> ProviderFixture:
inference_model = (
[inference_model] if isinstance(inference_model, str) else inference_model
)
print("!!!", inference_model)
if "Llama3.1-8B-Instruct" in inference_model:
pytest.skip("Ollama only supports Llama3.2-3B-Instruct for testing")

View file

@ -11,7 +11,7 @@ import pytest
import pytest_asyncio
from llama_stack.distribution.datatypes import Api, Provider
from llama_stack.providers.inline.meta_reference.memory import FaissImplConfig
from llama_stack.providers.inline.memory.faiss import FaissImplConfig
from llama_stack.providers.remote.memory.pgvector import PGVectorConfig
from llama_stack.providers.remote.memory.weaviate import WeaviateConfig

View file

@ -8,7 +8,7 @@ import pytest
import pytest_asyncio
from llama_stack.distribution.datatypes import Api, Provider
from llama_stack.providers.inline.meta_reference.safety import (
from llama_stack.providers.inline.safety.meta_reference import (
LlamaGuardShieldConfig,
SafetyConfig,
)

View file

@ -0,0 +1,68 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest
from ..conftest import get_provider_fixture_overrides
from ..datasetio.fixtures import DATASETIO_FIXTURES
from ..inference.fixtures import INFERENCE_FIXTURES
from .fixtures import SCORING_FIXTURES
DEFAULT_PROVIDER_COMBINATIONS = [
pytest.param(
{
"scoring": "meta_reference",
"datasetio": "meta_reference",
"inference": "fireworks",
},
id="meta_reference_scoring_fireworks_inference",
marks=pytest.mark.meta_reference_scoring_fireworks_inference,
),
pytest.param(
{
"scoring": "meta_reference",
"datasetio": "meta_reference",
"inference": "together",
},
id="meta_reference_scoring_together_inference",
marks=pytest.mark.meta_reference_scoring_together_inference,
),
]
def pytest_configure(config):
for fixture_name in [
"meta_reference_scoring_fireworks_inference",
"meta_reference_scoring_together_inference",
]:
config.addinivalue_line(
"markers",
f"{fixture_name}: marks tests as {fixture_name} specific",
)
def pytest_addoption(parser):
parser.addoption(
"--inference-model",
action="store",
default="Llama3.2-3B-Instruct",
help="Specify the inference model to use for testing",
)
def pytest_generate_tests(metafunc):
if "scoring_stack" in metafunc.fixturenames:
available_fixtures = {
"scoring": SCORING_FIXTURES,
"datasetio": DATASETIO_FIXTURES,
"inference": INFERENCE_FIXTURES,
}
combinations = (
get_provider_fixture_overrides(metafunc.config, available_fixtures)
or DEFAULT_PROVIDER_COMBINATIONS
)
metafunc.parametrize("scoring_stack", combinations, indirect=True)

View file

@ -0,0 +1,60 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest
import pytest_asyncio
from llama_stack.distribution.datatypes import Api, Provider
from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
from ..conftest import ProviderFixture, remote_stack_fixture
@pytest.fixture(scope="session")
def scoring_remote() -> ProviderFixture:
return remote_stack_fixture()
@pytest.fixture(scope="session")
def scoring_meta_reference() -> ProviderFixture:
return ProviderFixture(
providers=[
Provider(
provider_id="meta-reference",
provider_type="meta-reference",
config={},
)
],
)
SCORING_FIXTURES = ["meta_reference", "remote"]
@pytest_asyncio.fixture(scope="session")
async def scoring_stack(request):
fixture_dict = request.param
providers = {}
provider_data = {}
for key in ["datasetio", "scoring", "inference"]:
fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}")
providers[key] = fixture.providers
if fixture.provider_data:
provider_data.update(fixture.provider_data)
impls = await resolve_impls_for_test_v2(
[Api.scoring, Api.datasetio, Api.inference],
providers,
provider_data,
)
return (
impls[Api.scoring],
impls[Api.scoring_functions],
impls[Api.datasetio],
impls[Api.datasets],
)

View file

@ -1,17 +0,0 @@
providers:
datasetio:
- provider_id: test-meta
provider_type: meta-reference
config: {}
scoring:
- provider_id: test-meta
provider_type: meta-reference
config: {}
- provider_id: test-braintrust
provider_type: braintrust
config: {}
inference:
- provider_id: tgi0
provider_type: remote::tgi
config:
url: http://127.0.0.1:5009

View file

@ -3,150 +3,109 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest
import pytest_asyncio
from llama_stack.apis.common.type_system import * # noqa: F403
from llama_stack.apis.datasetio import * # noqa: F403
from llama_stack.distribution.datatypes import * # noqa: F403
import pytest
from llama_stack.apis.scoring_functions import * # noqa: F403
from llama_stack.providers.tests.datasetio.test_datasetio import register_dataset
from llama_stack.providers.tests.resolver import resolve_impls_for_test
# How to run this test:
#
# 1. Ensure you have a conda with the right dependencies installed. This is a bit tricky
# since it depends on the provider you are testing. On top of that you need
# `pytest` and `pytest-asyncio` installed.
#
# 2. Copy and modify the provider_config_example.yaml depending on the provider you are testing.
#
# 3. Run:
#
# ```bash
# PROVIDER_ID=<your_provider> \
# PROVIDER_CONFIG=provider_config.yaml \
# pytest -s llama_stack/providers/tests/scoring/test_scoring.py \
# --tb=short --disable-warnings
# ```
# pytest llama_stack/providers/tests/scoring/test_scoring.py
# -m "meta_reference"
# -v -s --tb=short --disable-warnings
@pytest_asyncio.fixture(scope="session")
async def scoring_settings():
impls = await resolve_impls_for_test(
Api.scoring, deps=[Api.datasetio, Api.inference]
)
return {
"scoring_impl": impls[Api.scoring],
"scoring_functions_impl": impls[Api.scoring_functions],
"datasets_impl": impls[Api.datasets],
}
@pytest_asyncio.fixture(scope="session")
async def provider_scoring_functions():
return {
"meta-reference": {
"meta-reference::equality",
"meta-reference::subset_of",
"meta-reference::llm_as_judge_8b_correctness",
},
"braintrust": {
"braintrust::factuality",
"braintrust::answer-correctness",
},
}
class TestScoring:
@pytest.mark.asyncio
async def test_scoring_functions_list(self, scoring_stack):
# NOTE: this needs you to ensure that you are starting from a clean state
# but so far we don't have an unregister API unfortunately, so be careful
_, scoring_functions_impl, _, _ = scoring_stack
response = await scoring_functions_impl.list_scoring_functions()
assert isinstance(response, list)
assert len(response) > 0
@pytest.mark.asyncio
async def test_scoring_functions_list(scoring_settings, provider_scoring_functions):
scoring_impl = scoring_settings["scoring_impl"]
scoring_functions_impl = scoring_settings["scoring_functions_impl"]
scoring_functions = await scoring_functions_impl.list_scoring_functions()
assert isinstance(scoring_functions, list)
assert len(scoring_functions) > 0
function_ids = [f.identifier for f in scoring_functions]
# get current provider_type we're testing
provider = scoring_impl.routing_table.get_provider_impl(function_ids[0])
provider_type = provider.__provider_spec__.provider_type
for x in provider_scoring_functions[provider_type]:
assert x in function_ids
@pytest.mark.asyncio
async def test_scoring_functions_register(scoring_settings):
scoring_impl = scoring_settings["scoring_impl"]
scoring_functions_impl = scoring_settings["scoring_functions_impl"]
datasets_impl = scoring_settings["datasets_impl"]
# get current provider_type we're testing
scoring_functions = await scoring_functions_impl.list_scoring_functions()
function_ids = [f.identifier for f in scoring_functions]
provider = scoring_impl.routing_table.get_provider_impl(function_ids[0])
provider_type = provider.__provider_spec__.provider_type
if provider_type not in ("meta-reference"):
pytest.skip(
"Other scoring providers don't support registering scoring functions."
async def test_scoring_score(self, scoring_stack):
scoring_impl, scoring_functions_impl, datasetio_impl, datasets_impl = (
scoring_stack
)
test_prompt = """Output a number between 0 to 10. Your answer must match the format \n Number: <answer>"""
# register the scoring function
await scoring_functions_impl.register_scoring_function(
ScoringFnDefWithProvider(
identifier="meta-reference::llm_as_judge_8b_random",
description="Llm As Judge Scoring Function",
parameters=[],
return_type=NumberType(),
context=LLMAsJudgeContext(
prompt_template=test_prompt,
judge_model="Llama3.1-8B-Instruct",
judge_score_regex=[r"Number: (\d+)"],
),
provider_id="test-meta",
)
)
scoring_functions = await scoring_functions_impl.list_scoring_functions()
assert isinstance(scoring_functions, list)
assert len(scoring_functions) > 0
function_ids = [f.identifier for f in scoring_functions]
assert "meta-reference::llm_as_judge_8b_random" in function_ids
# test score using newly registered scoring function
await register_dataset(datasets_impl)
response = await datasets_impl.list_datasets()
assert len(response) == 1
response = await scoring_impl.score_batch(
dataset_id=response[0].identifier,
scoring_functions=[
"meta-reference::llm_as_judge_8b_random",
],
# scoring individual rows
rows = await datasetio_impl.get_rows_paginated(
dataset_id="test_dataset",
rows_in_page=3,
)
assert "meta-reference::llm_as_judge_8b_random" in response.results
assert len(rows.rows) == 3
@pytest.mark.asyncio
async def test_scoring_score(scoring_settings, provider_scoring_functions):
scoring_impl = scoring_settings["scoring_impl"]
datasets_impl = scoring_settings["datasets_impl"]
scoring_functions_impl = scoring_settings["scoring_functions_impl"]
await register_dataset(datasets_impl)
response = await datasets_impl.list_datasets()
assert len(response) == 1
# get current provider_type we're testing
scoring_functions = await scoring_functions_impl.list_scoring_functions()
function_ids = [f.identifier for f in scoring_functions]
provider = scoring_impl.routing_table.get_provider_impl(function_ids[0])
provider_type = provider.__provider_spec__.provider_type
response = await scoring_impl.score_batch(
dataset_id=response[0].identifier,
scoring_functions=list(provider_scoring_functions[provider_type]),
scoring_functions = {
"meta-reference::llm_as_judge_8b_correctness": None,
"meta-reference::equality": None,
}
response = await scoring_impl.score(
input_rows=rows.rows,
scoring_functions=scoring_functions,
)
assert len(response.results) == len(provider_scoring_functions[provider_type])
for x in provider_scoring_functions[provider_type]:
assert len(response.results) == len(scoring_functions)
for x in scoring_functions:
assert x in response.results
assert len(response.results[x].score_rows) == len(rows.rows)
# score batch
response = await scoring_impl.score_batch(
dataset_id="test_dataset",
scoring_functions=scoring_functions,
)
assert len(response.results) == len(scoring_functions)
for x in scoring_functions:
assert x in response.results
assert len(response.results[x].score_rows) == 5
@pytest.mark.asyncio
async def test_scoring_score_with_params(self, scoring_stack):
scoring_impl, scoring_functions_impl, datasetio_impl, datasets_impl = (
scoring_stack
)
await register_dataset(datasets_impl)
response = await datasets_impl.list_datasets()
assert len(response) == 1
# scoring individual rows
rows = await datasetio_impl.get_rows_paginated(
dataset_id="test_dataset",
rows_in_page=3,
)
assert len(rows.rows) == 3
scoring_functions = {
"meta-reference::llm_as_judge_8b_correctness": LLMAsJudgeScoringFnParams(
judge_model="Llama3.1-405B-Instruct",
prompt_template="Output a number response in the following format: Score: <number>, where <number> is the number between 0 and 9.",
judge_score_regexes=[r"Score: (\d+)"],
)
}
response = await scoring_impl.score(
input_rows=rows.rows,
scoring_functions=scoring_functions,
)
assert len(response.results) == len(scoring_functions)
for x in scoring_functions:
assert x in response.results
assert len(response.results[x].score_rows) == len(rows.rows)
# score batch
response = await scoring_impl.score_batch(
dataset_id="test_dataset",
scoring_functions=scoring_functions,
)
assert len(response.results) == len(scoring_functions)
for x in scoring_functions:
assert x in response.results
assert len(response.results[x].score_rows) == 5