Merge branch 'main' of https://github.com/santiagxf/llama-stack into santiagxf/azure-ai-inference

Facundo Santiago 2024-11-08 15:04:48 +00:00
commit 75f742775d
98 changed files with 1131 additions and 586 deletions


@@ -22,6 +22,7 @@ pip install -r requirements.txt
 pip install sphinx-autobuild
 # This will start a local server (usually at http://127.0.0.1:8000) that automatically rebuilds and refreshes when you make changes to the documentation.
+make html
 sphinx-autobuild source build/html
 ```


@@ -1,15 +1,42 @@
 # Remote-Hosted Distribution
-Remote Hosted distributions are distributions connecting to remote hosted services through Llama Stack server. Inference is done through remote providers. These are useful if you have an API key for a remote inference provider like Fireworks, Together, etc.
+Remote-Hosted distributions are available endpoints serving Llama Stack API that you can directly connect to.
-| **Distribution** | **Llama Stack Docker** | Start This Distribution | **Inference** | **Agents** | **Memory** | **Safety** | **Telemetry** |
-|:----------------: |:------------------------------------------: |:-----------------------: |:------------------: |:------------------: |:------------------: |:------------------: |:------------------: |
-| Together | [llamastack/distribution-together](https://hub.docker.com/repository/docker/llamastack/distribution-together/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/getting_started/distributions/remote_hosted_distro/together.html) | remote::together | meta-reference | remote::weaviate | meta-reference | meta-reference |
-| Fireworks | [llamastack/distribution-fireworks](https://hub.docker.com/repository/docker/llamastack/distribution-fireworks/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/getting_started/distributions/remote_hosted_distro/fireworks.html) | remote::fireworks | meta-reference | remote::weaviate | meta-reference | meta-reference |
+| Distribution | Endpoint | Inference | Agents | Memory | Safety | Telemetry |
+|-------------|----------|-----------|---------|---------|---------|------------|
+| Together | [https://llama-stack.together.ai](https://llama-stack.together.ai) | remote::together | meta-reference | remote::weaviate | meta-reference | meta-reference |
+| Fireworks | [https://llamastack-preview.fireworks.ai](https://llamastack-preview.fireworks.ai) | remote::fireworks | meta-reference | remote::weaviate | meta-reference | meta-reference |
-```{toctree}
-:maxdepth: 1
-fireworks
-together
-```
+## Connecting to Remote-Hosted Distributions
+
+You can use `llama-stack-client` to interact with these endpoints. For example, to list the available models served by the Fireworks endpoint:
+
+```bash
+$ pip install llama-stack-client
+$ llama-stack-client configure --endpoint https://llamastack-preview.fireworks.ai
+$ llama-stack-client models list
+```
+
+You will see outputs:
+```
+$ llama-stack-client models list
++------------------------------+------------------------------+---------------+------------+
+| identifier | llama_model | provider_id | metadata |
++==============================+==============================+===============+============+
+| Llama3.1-8B-Instruct | Llama3.1-8B-Instruct | fireworks0 | {} |
++------------------------------+------------------------------+---------------+------------+
+| Llama3.1-70B-Instruct | Llama3.1-70B-Instruct | fireworks0 | {} |
++------------------------------+------------------------------+---------------+------------+
+| Llama3.1-405B-Instruct | Llama3.1-405B-Instruct | fireworks0 | {} |
++------------------------------+------------------------------+---------------+------------+
+| Llama3.2-1B-Instruct | Llama3.2-1B-Instruct | fireworks0 | {} |
++------------------------------+------------------------------+---------------+------------+
+| Llama3.2-3B-Instruct | Llama3.2-3B-Instruct | fireworks0 | {} |
++------------------------------+------------------------------+---------------+------------+
+| Llama3.2-11B-Vision-Instruct | Llama3.2-11B-Vision-Instruct | fireworks0 | {} |
++------------------------------+------------------------------+---------------+------------+
+| Llama3.2-90B-Vision-Instruct | Llama3.2-90B-Vision-Instruct | fireworks0 | {} |
++------------------------------+------------------------------+---------------+------------+
+```
+
+Checkout the [llama-stack-client-python](https://github.com/meta-llama/llama-stack-client-python/blob/main/docs/cli_reference.md) repo for more details on how to use the `llama-stack-client` CLI. Checkout [llama-stack-app](https://github.com/meta-llama/llama-stack-apps/tree/main) for examples applications built on top of Llama Stack.

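As an editorial aside, the same check can be done from Python instead of the CLI. This is a minimal sketch only, assuming the `llama-stack-client` Python package exposes a `LlamaStackClient` class that accepts a `base_url` and offers a `models.list()` call; consult the llama-stack-client-python repository linked above for the actual API.

```python
# Hedged sketch: verify class and method names against llama-stack-client-python.
from llama_stack_client import LlamaStackClient

# Point the client at a remote-hosted distribution endpoint (Fireworks preview here).
client = LlamaStackClient(base_url="https://llamastack-preview.fireworks.ai")

# Equivalent to `llama-stack-client models list` on the CLI.
for model in client.models.list():
    print(model.identifier)
```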

@@ -8,6 +8,10 @@ We offer deployable distributions where you can host your own Llama Stack server
 | Meta Reference Quantized | [llamastack/distribution-meta-reference-quantized-gpu](https://hub.docker.com/repository/docker/llamastack/distribution-meta-reference-quantized-gpu/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/getting_started/distributions/self_hosted_distro/meta-reference-quantized-gpu.html) | meta-reference-quantized | meta-reference | meta-reference; remote::pgvector; remote::chromadb | meta-reference | meta-reference |
 | Ollama | [llamastack/distribution-ollama](https://hub.docker.com/repository/docker/llamastack/distribution-ollama/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/getting_started/distributions/self_hosted_distro/ollama.html) | remote::ollama | meta-reference | remote::pgvector; remote::chromadb | meta-reference | meta-reference |
 | TGI | [llamastack/distribution-tgi](https://hub.docker.com/repository/docker/llamastack/distribution-tgi/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/getting_started/distributions/self_hosted_distro/tgi.html) | remote::tgi | meta-reference | meta-reference; remote::pgvector; remote::chromadb | meta-reference | meta-reference |
+| Together | [llamastack/distribution-together](https://hub.docker.com/repository/docker/llamastack/distribution-together/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/getting_started/distributions/self_hosted_distro/together.html) | remote::together | meta-reference | remote::weaviate | meta-reference | meta-reference |
+| Fireworks | [llamastack/distribution-fireworks](https://hub.docker.com/repository/docker/llamastack/distribution-fireworks/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/getting_started/distributions/self_hosted_distro/fireworks.html) | remote::fireworks | meta-reference | remote::weaviate | meta-reference | meta-reference |
+| Bedrock | [llamastack/distribution-bedrock](https://hub.docker.com/repository/docker/llamastack/distribution-bedrock/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/getting_started/distributions/self_hosted_distro/bedrock.html) | remote::bedrock | meta-reference | remote::weaviate | meta-reference | meta-reference |
 ```{toctree}
 :maxdepth: 1
@@ -17,4 +21,7 @@ meta-reference-quantized-gpu
 ollama
 tgi
 dell-tgi
+together
+fireworks
+bedrock
 ```


@@ -14,6 +14,7 @@ from llama_stack.apis.scoring_functions import * # noqa: F403
 from llama_stack.apis.agents import AgentConfig
 from llama_stack.apis.common.job_types import Job, JobStatus
 from llama_stack.apis.scoring import * # noqa: F403
+from llama_stack.apis.eval_tasks import * # noqa: F403

 @json_schema_type
@@ -35,36 +36,57 @@ EvalCandidate = Annotated[
 ]

+@json_schema_type
+class BenchmarkEvalTaskConfig(BaseModel):
+    type: Literal["benchmark"] = "benchmark"
+    eval_candidate: EvalCandidate
+
+
+@json_schema_type
+class AppEvalTaskConfig(BaseModel):
+    type: Literal["app"] = "app"
+    eval_candidate: EvalCandidate
+    scoring_params: Dict[str, ScoringFnParams] = Field(
+        description="Map between scoring function id and parameters for each scoring function you want to run",
+        default_factory=dict,
+    )
+    # we could optinally add any specific dataset config here
+
+
+EvalTaskConfig = Annotated[
+    Union[BenchmarkEvalTaskConfig, AppEvalTaskConfig], Field(discriminator="type")
+]
+
+
 @json_schema_type
 class EvaluateResponse(BaseModel):
     generations: List[Dict[str, Any]]
     # each key in the dict is a scoring function name
     scores: Dict[str, ScoringResult]

 class Eval(Protocol):
-    @webmethod(route="/eval/evaluate_batch", method="POST")
-    async def evaluate_batch(
+    @webmethod(route="/eval/run_eval", method="POST")
+    async def run_eval(
         self,
-        dataset_id: str,
-        candidate: EvalCandidate,
-        scoring_functions: List[str],
+        task_id: str,
+        task_config: EvalTaskConfig,
     ) -> Job: ...

-    @webmethod(route="/eval/evaluate", method="POST")
-    async def evaluate(
+    @webmethod(route="/eval/evaluate_rows", method="POST")
+    async def evaluate_rows(
         self,
+        task_id: str,
         input_rows: List[Dict[str, Any]],
-        candidate: EvalCandidate,
         scoring_functions: List[str],
+        task_config: EvalTaskConfig,
     ) -> EvaluateResponse: ...

     @webmethod(route="/eval/job/status", method="GET")
-    async def job_status(self, job_id: str) -> Optional[JobStatus]: ...
+    async def job_status(self, task_id: str, job_id: str) -> Optional[JobStatus]: ...

     @webmethod(route="/eval/job/cancel", method="POST")
-    async def job_cancel(self, job_id: str) -> None: ...
+    async def job_cancel(self, task_id: str, job_id: str) -> None: ...

     @webmethod(route="/eval/job/result", method="GET")
-    async def job_result(self, job_id: str) -> EvaluateResponse: ...
+    async def job_result(self, task_id: str, job_id: str) -> EvaluateResponse: ...

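To make the reworked `Eval` protocol above concrete, here is a hedged sketch of how a caller might drive it once an eval task has been registered. Only `AppEvalTaskConfig`, `LLMAsJudgeScoringFnParams`, and the method signatures come from the hunks in this commit; the task id, the scoring function id, and the import paths are illustrative assumptions.

```python
# Sketch under assumptions: ids are made up and the candidate is passed in,
# not constructed here.
from llama_stack.apis.eval import AppEvalTaskConfig, Eval, EvalCandidate  # assumed re-exports
from llama_stack.apis.scoring_functions import LLMAsJudgeScoringFnParams


async def run_registered_task(eval_api: Eval, candidate: EvalCandidate) -> None:
    task_config = AppEvalTaskConfig(
        eval_candidate=candidate,
        # Per-run overrides keyed by scoring function id; omitted ids keep registered params.
        scoring_params={
            "llm_as_judge_8b_correctness": LLMAsJudgeScoringFnParams(
                judge_model="Llama3.1-8B-Instruct",
                judge_score_regexes=[r"Total rating: (\d+)"],
            ),
        },
    )

    # The dataset and scoring functions now come from the registered eval task,
    # so the caller only passes the task id plus the per-run config.
    job = await eval_api.run_eval(task_id="my-app-eval", task_config=task_config)

    # Job-control endpoints are task-scoped as well.
    result = await eval_api.job_result(task_id="my-app-eval", job_id=job.job_id)
    print(result.scores)
```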

@@ -0,0 +1,7 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from .eval_tasks import * # noqa: F401 F403


@@ -0,0 +1,43 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable
+
+from llama_models.schema_utils import json_schema_type, webmethod
+from pydantic import BaseModel, Field
+
+
+@json_schema_type
+class EvalTaskDef(BaseModel):
+    identifier: str
+    dataset_id: str
+    scoring_functions: List[str]
+    metadata: Dict[str, Any] = Field(
+        default_factory=dict,
+        description="Metadata for this evaluation task",
+    )
+
+
+@json_schema_type
+class EvalTaskDefWithProvider(EvalTaskDef):
+    type: Literal["eval_task"] = "eval_task"
+    provider_id: str = Field(
+        description="ID of the provider which serves this dataset",
+    )
+
+
+@runtime_checkable
+class EvalTasks(Protocol):
+    @webmethod(route="/eval_tasks/list", method="GET")
+    async def list_eval_tasks(self) -> List[EvalTaskDefWithProvider]: ...
+
+    @webmethod(route="/eval_tasks/get", method="GET")
+    async def get_eval_task(self, name: str) -> Optional[EvalTaskDefWithProvider]: ...
+
+    @webmethod(route="/eval_tasks/register", method="POST")
+    async def register_eval_task(
+        self, eval_task_def: EvalTaskDefWithProvider
+    ) -> None: ...

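A short sketch of how the new `EvalTasks` registry is meant to be used, based only on the definitions in this new file; the identifier, dataset id, scoring function id, and provider id are placeholder values.

```python
# Sketch: register an eval task, then look it up again. All ids are illustrative.
from llama_stack.apis.eval_tasks import EvalTaskDefWithProvider, EvalTasks


async def register_example_task(eval_tasks_api: EvalTasks) -> None:
    task = EvalTaskDefWithProvider(
        identifier="my-app-eval",
        dataset_id="my-qa-dataset",
        scoring_functions=["subset_of"],
        metadata={"owner": "docs-example"},
        provider_id="meta0",
    )
    await eval_tasks_api.register_eval_task(eval_task_def=task)

    # The task is now discoverable through the same protocol.
    all_tasks = await eval_tasks_api.list_eval_tasks()
    fetched = await eval_tasks_api.get_eval_task(name="my-app-eval")
    assert fetched is not None and fetched in all_tasks
```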

@@ -48,11 +48,13 @@ class Scoring(Protocol):
     async def score_batch(
         self,
         dataset_id: str,
-        scoring_functions: List[str],
+        scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
         save_results_dataset: bool = False,
     ) -> ScoreBatchResponse: ...

     @webmethod(route="/scoring/score")
     async def score(
-        self, input_rows: List[Dict[str, Any]], scoring_functions: List[str]
+        self,
+        input_rows: List[Dict[str, Any]],
+        scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
     ) -> ScoreResponse: ...

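The signature change above means callers now pass a mapping from scoring function id to optional parameter overrides instead of a plain list of ids. A hedged sketch of the new call shape follows; the function ids and row contents are illustrative, and the import path for `Scoring` is assumed.

```python
# Sketch: None keeps a scoring function's registered params; a ScoringFnParams value
# overrides them for this call only.
from llama_stack.apis.scoring import Scoring  # assumed import path
from llama_stack.apis.scoring_functions import LLMAsJudgeScoringFnParams


async def score_one_row(scoring_api: Scoring) -> None:
    response = await scoring_api.score(
        input_rows=[
            {
                "input_query": "What is the capital of France?",
                "expected_answer": "Paris",
                "generated_answer": "The capital of France is Paris.",
            }
        ],
        scoring_functions={
            "subset_of": None,  # use the registered defaults
            "llm_as_judge_8b_correctness": LLMAsJudgeScoringFnParams(
                judge_model="Llama3.1-8B-Instruct",
                judge_score_regexes=[r"Total rating: (\d+)"],
            ),
        },
    )
    print(response.results)
```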

@@ -4,34 +4,66 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable
+from enum import Enum
+from typing import (
+    Any,
+    Dict,
+    List,
+    Literal,
+    Optional,
+    Protocol,
+    runtime_checkable,
+    Union,
+)

 from llama_models.schema_utils import json_schema_type, webmethod
 from pydantic import BaseModel, Field
+from typing_extensions import Annotated

 from llama_stack.apis.common.type_system import ParamType

-@json_schema_type
-class Parameter(BaseModel):
-    name: str
-    type: ParamType
-    description: Optional[str] = None

 # Perhaps more structure can be imposed on these functions. Maybe they could be associated
 # with standard metrics so they can be rolled up?
+@json_schema_type
+class ScoringConfigType(Enum):
+    llm_as_judge = "llm_as_judge"
+    regex_parser = "regex_parser"


-class LLMAsJudgeContext(BaseModel):
+@json_schema_type
+class LLMAsJudgeScoringFnParams(BaseModel):
+    type: Literal[ScoringConfigType.llm_as_judge.value] = (
+        ScoringConfigType.llm_as_judge.value
+    )
     judge_model: str
     prompt_template: Optional[str] = None
-    judge_score_regex: Optional[List[str]] = Field(
-        description="Regex to extract the score from the judge response",
-        default=None,
+    judge_score_regexes: Optional[List[str]] = Field(
+        description="Regexes to extract the answer from generated response",
+        default_factory=list,
     )


+@json_schema_type
+class RegexParserScoringFnParams(BaseModel):
+    type: Literal[ScoringConfigType.regex_parser.value] = (
+        ScoringConfigType.regex_parser.value
+    )
+    parsing_regexes: Optional[List[str]] = Field(
+        description="Regex to extract the answer from generated response",
+        default_factory=list,
+    )
+
+
+ScoringFnParams = Annotated[
+    Union[
+        LLMAsJudgeScoringFnParams,
+        RegexParserScoringFnParams,
+    ],
+    Field(discriminator="type"),
+]
+
+
 @json_schema_type
 class ScoringFnDef(BaseModel):
     identifier: str
@@ -40,14 +72,13 @@ class ScoringFnDef(BaseModel):
         default_factory=dict,
         description="Any additional metadata for this definition",
     )
-    parameters: List[Parameter] = Field(
-        description="List of parameters for the deterministic function",
-        default_factory=list,
-    )
     return_type: ParamType = Field(
         description="The return type of the deterministic function",
     )
-    context: Optional[LLMAsJudgeContext] = None
+    params: Optional[ScoringFnParams] = Field(
+        description="The parameters for the scoring function for benchmark eval, these can be overridden for app eval",
+        default=None,
+    )

     # We can optionally add information here to support packaging of code, etc.

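Since `ScoringFnParams` is now a discriminated union, the `type` literal decides which concrete class a serialized payload deserializes into. A small sketch of that behaviour, assuming pydantic v2 (which provides `TypeAdapter`) and that these names are importable from `llama_stack.apis.scoring_functions`:

```python
# Sketch: round-tripping params through the "type" discriminator.
from pydantic import TypeAdapter

from llama_stack.apis.scoring_functions import (
    LLMAsJudgeScoringFnParams,
    RegexParserScoringFnParams,
    ScoringFnParams,
)

adapter = TypeAdapter(ScoringFnParams)

judge = adapter.validate_python(
    {
        "type": "llm_as_judge",
        "judge_model": "Llama3.1-8B-Instruct",
        "judge_score_regexes": [r"Rating: (\d+)"],
    }
)
assert isinstance(judge, LLMAsJudgeScoringFnParams)

regex = adapter.validate_python(
    {"type": "regex_parser", "parsing_regexes": [r"Answer:\s*(.*)"]}
)
assert isinstance(regex, RegexParserScoringFnParams)
```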

@@ -43,6 +43,10 @@ def builtin_automatically_routed_apis() -> List[AutoRoutedApiInfo]:
             routing_table_api=Api.scoring_functions,
             router_api=Api.scoring,
         ),
+        AutoRoutedApiInfo(
+            routing_table_api=Api.eval_tasks,
+            router_api=Api.eval,
+        ),
     ]


@@ -8,6 +8,8 @@ import inspect
 from typing import Any, Dict, List, Set

+from termcolor import cprint
+
 from llama_stack.providers.datatypes import * # noqa: F403
 from llama_stack.distribution.datatypes import * # noqa: F403
@@ -15,6 +17,7 @@ from llama_stack.apis.agents import Agents
 from llama_stack.apis.datasetio import DatasetIO
 from llama_stack.apis.datasets import Datasets
 from llama_stack.apis.eval import Eval
+from llama_stack.apis.eval_tasks import EvalTasks
 from llama_stack.apis.inference import Inference
 from llama_stack.apis.inspect import Inspect
 from llama_stack.apis.memory import Memory
@@ -46,6 +49,7 @@ def api_protocol_map() -> Dict[Api, Any]:
         Api.scoring: Scoring,
         Api.scoring_functions: ScoringFunctions,
         Api.eval: Eval,
+        Api.eval_tasks: EvalTasks,
     }
@@ -56,6 +60,7 @@ def additional_protocols_map() -> Dict[Api, Any]:
         Api.safety: (ShieldsProtocolPrivate, Shields),
         Api.datasetio: (DatasetsProtocolPrivate, Datasets),
         Api.scoring: (ScoringFunctionsProtocolPrivate, ScoringFunctions),
+        Api.eval_tasks: (EvalTasksProtocolPrivate, EvalTasks),
     }
@@ -97,6 +102,12 @@ async def resolve_impls(
             )

         p = provider_registry[api][provider.provider_type]
+        if p.deprecation_warning:
+            cprint(
+                f"Provider `{provider.provider_type}` for API `{api}` is deprecated and will be removed in a future release: {p.deprecation_warning}",
+                "red",
+                attrs=["bold"],
+            )
         p.deps__ = [a.value for a in p.api_dependencies]
         spec = ProviderWithSpec(
             spec=p,


@@ -12,6 +12,7 @@ from llama_stack.distribution.store import DistributionRegistry
 from .routing_tables import (
     DatasetsRoutingTable,
+    EvalTasksRoutingTable,
     MemoryBanksRoutingTable,
     ModelsRoutingTable,
     ScoringFunctionsRoutingTable,
@@ -31,6 +32,7 @@ async def get_routing_table_impl(
         "shields": ShieldsRoutingTable,
         "datasets": DatasetsRoutingTable,
         "scoring_functions": ScoringFunctionsRoutingTable,
+        "eval_tasks": EvalTasksRoutingTable,
     }

     if api.value not in api_to_tables:
@@ -44,6 +46,7 @@ async def get_routing_table_impl(
 async def get_auto_router_impl(api: Api, routing_table: RoutingTable, _deps) -> Any:
     from .routers import (
         DatasetIORouter,
+        EvalRouter,
         InferenceRouter,
         MemoryRouter,
         SafetyRouter,
@@ -56,6 +59,7 @@ async def get_auto_router_impl(api: Api, routing_table: RoutingTable, _deps) ->
         "safety": SafetyRouter,
         "datasetio": DatasetIORouter,
         "scoring": ScoringRouter,
+        "eval": EvalRouter,
     }
     if api.value not in api_to_routers:
         raise ValueError(f"API {api.value} not found in router map")


@@ -14,6 +14,7 @@ from llama_stack.apis.inference import * # noqa: F403
 from llama_stack.apis.safety import * # noqa: F403
 from llama_stack.apis.datasetio import * # noqa: F403
 from llama_stack.apis.scoring import * # noqa: F403
+from llama_stack.apis.eval import * # noqa: F403

 class MemoryRouter(Memory):
@@ -211,16 +212,16 @@ class ScoringRouter(Scoring):
     async def score_batch(
         self,
         dataset_id: str,
-        scoring_functions: List[str],
+        scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
         save_results_dataset: bool = False,
     ) -> ScoreBatchResponse:
         res = {}
-        for fn_identifier in scoring_functions:
+        for fn_identifier in scoring_functions.keys():
             score_response = await self.routing_table.get_provider_impl(
                 fn_identifier
             ).score_batch(
                 dataset_id=dataset_id,
-                scoring_functions=[fn_identifier],
+                scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
             )
             res.update(score_response.results)
@@ -232,17 +233,87 @@ class ScoringRouter(Scoring):
         )

     async def score(
-        self, input_rows: List[Dict[str, Any]], scoring_functions: List[str]
+        self,
+        input_rows: List[Dict[str, Any]],
+        scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
     ) -> ScoreResponse:
         res = {}
         # look up and map each scoring function to its provider impl
-        for fn_identifier in scoring_functions:
+        for fn_identifier in scoring_functions.keys():
             score_response = await self.routing_table.get_provider_impl(
                 fn_identifier
             ).score(
                 input_rows=input_rows,
-                scoring_functions=[fn_identifier],
+                scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
             )
             res.update(score_response.results)

         return ScoreResponse(results=res)
+
+
+class EvalRouter(Eval):
+    def __init__(
+        self,
+        routing_table: RoutingTable,
+    ) -> None:
+        self.routing_table = routing_table
+
+    async def initialize(self) -> None:
+        pass
+
+    async def shutdown(self) -> None:
+        pass
+
+    async def run_eval(
+        self,
+        task_id: str,
+        task_config: AppEvalTaskConfig,
+    ) -> Job:
+        return await self.routing_table.get_provider_impl(task_id).run_eval(
+            task_id=task_id,
+            task_config=task_config,
+        )
+
+    @webmethod(route="/eval/evaluate_rows", method="POST")
+    async def evaluate_rows(
+        self,
+        task_id: str,
+        input_rows: List[Dict[str, Any]],
+        scoring_functions: List[str],
+        task_config: EvalTaskConfig,
+    ) -> EvaluateResponse:
+        return await self.routing_table.get_provider_impl(task_id).evaluate_rows(
+            task_id=task_id,
+            input_rows=input_rows,
+            scoring_functions=scoring_functions,
+            task_config=task_config,
+        )
+
+    async def job_status(
+        self,
+        task_id: str,
+        job_id: str,
+    ) -> Optional[JobStatus]:
+        return await self.routing_table.get_provider_impl(task_id).job_status(
+            task_id, job_id
+        )
+
+    async def job_cancel(
+        self,
+        task_id: str,
+        job_id: str,
+    ) -> None:
+        await self.routing_table.get_provider_impl(task_id).job_cancel(
+            task_id,
+            job_id,
+        )
+
+    async def job_result(
+        self,
+        task_id: str,
+        job_id: str,
+    ) -> EvaluateResponse:
+        return await self.routing_table.get_provider_impl(task_id).job_result(
+            task_id,
+            job_id,
+        )


@@ -12,6 +12,8 @@ from llama_stack.apis.models import * # noqa: F403
 from llama_stack.apis.shields import * # noqa: F403
 from llama_stack.apis.memory_banks import * # noqa: F403
 from llama_stack.apis.datasets import * # noqa: F403
+from llama_stack.apis.eval_tasks import * # noqa: F403

 from llama_stack.distribution.store import DistributionRegistry
 from llama_stack.distribution.datatypes import * # noqa: F403
@@ -40,6 +42,8 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> None:
         await p.register_dataset(obj)
     elif api == Api.scoring:
         await p.register_scoring_function(obj)
+    elif api == Api.eval:
+        await p.register_eval_task(obj)
     else:
         raise ValueError(f"Unknown API {api} for registering object with provider")
@@ -103,6 +107,11 @@ class CommonRoutingTableImpl(RoutingTable):
                 scoring_functions = await p.list_scoring_functions()
                 await add_objects(scoring_functions, pid, ScoringFnDefWithProvider)
+            elif api == Api.eval:
+                p.eval_task_store = self
+                eval_tasks = await p.list_eval_tasks()
+                await add_objects(eval_tasks, pid, EvalTaskDefWithProvider)

     async def shutdown(self) -> None:
         for p in self.impls_by_provider_id.values():
             await p.shutdown()
@@ -121,6 +130,8 @@ class CommonRoutingTableImpl(RoutingTable):
             return ("DatasetIO", "dataset")
         elif isinstance(self, ScoringFunctionsRoutingTable):
             return ("Scoring", "scoring_function")
+        elif isinstance(self, EvalTasksRoutingTable):
+            return ("Eval", "eval_task")
         else:
             raise ValueError("Unknown routing table type")
@@ -246,9 +257,9 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
         await self.register_object(dataset_def)

-class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, Scoring):
+class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
     async def list_scoring_functions(self) -> List[ScoringFnDefWithProvider]:
-        return await self.get_all_with_type("scoring_function")
+        return await self.get_all_with_type("scoring_fn")

     async def get_scoring_function(
         self, name: str
@@ -259,3 +270,14 @@ class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, Scoring):
         self, function_def: ScoringFnDefWithProvider
     ) -> None:
         await self.register_object(function_def)
+
+
+class EvalTasksRoutingTable(CommonRoutingTableImpl, EvalTasks):
+    async def list_eval_tasks(self) -> List[ScoringFnDefWithProvider]:
+        return await self.get_all_with_type("eval_task")
+
+    async def get_eval_task(self, name: str) -> Optional[EvalTaskDefWithProvider]:
+        return await self.get_object_by_identifier(name)
+
+    async def register_eval_task(self, eval_task_def: EvalTaskDefWithProvider) -> None:
+        await self.register_object(eval_task_def)


@@ -31,7 +31,7 @@ from llama_stack.distribution.distribution import (
     get_provider_registry,
 )

-from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
+from llama_stack.distribution.store.registry import create_dist_registry

 from llama_stack.providers.utils.telemetry.tracing import (
     end_trace,
@@ -42,8 +42,6 @@ from llama_stack.providers.utils.telemetry.tracing import (
 from llama_stack.distribution.datatypes import * # noqa: F403
 from llama_stack.distribution.request_headers import set_request_provider_data
 from llama_stack.distribution.resolver import resolve_impls
-from llama_stack.distribution.store import CachedDiskDistributionRegistry
-from llama_stack.providers.utils.kvstore import kvstore_impl, SqliteKVStoreConfig

 from .endpoints import get_all_api_endpoints
@@ -281,21 +279,8 @@ def main(
         config = StackRunConfig(**yaml.safe_load(fp))

     app = FastAPI()
-    # instantiate kvstore for storing and retrieving distribution metadata
-    if config.metadata_store:
-        dist_kvstore = asyncio.run(kvstore_impl(config.metadata_store))
-    else:
-        dist_kvstore = asyncio.run(
-            kvstore_impl(
-                SqliteKVStoreConfig(
-                    db_path=(
-                        DISTRIBS_BASE_DIR / config.image_name / "kvstore.db"
-                    ).as_posix()
-                )
-            )
-        )
-
-    dist_registry = CachedDiskDistributionRegistry(dist_kvstore)
+    dist_registry, dist_kvstore = asyncio.run(create_dist_registry(config))

     impls = asyncio.run(resolve_impls(config, get_provider_registry(), dist_registry))
     if Api.telemetry in impls:


@@ -9,9 +9,17 @@ from typing import Dict, List, Protocol

 import pydantic

-from llama_stack.distribution.datatypes import RoutableObjectWithProvider
+from llama_stack.distribution.datatypes import (
+    RoutableObjectWithProvider,
+    StackRunConfig,
+)
+from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR

-from llama_stack.providers.utils.kvstore import KVStore
+from llama_stack.providers.utils.kvstore import (
+    KVStore,
+    kvstore_impl,
+    SqliteKVStoreConfig,
+)

 class DistributionRegistry(Protocol):
@@ -133,3 +141,21 @@ class CachedDiskDistributionRegistry(DiskDistributionRegistry):
             self.cache[obj.identifier].append(obj)
         return success
+
+
+async def create_dist_registry(
+    config: StackRunConfig,
+) -> tuple[CachedDiskDistributionRegistry, KVStore]:
+    # instantiate kvstore for storing and retrieving distribution metadata
+    if config.metadata_store:
+        dist_kvstore = await kvstore_impl(config.metadata_store)
+    else:
+        dist_kvstore = await kvstore_impl(
+            SqliteKVStoreConfig(
+                db_path=(
+                    DISTRIBS_BASE_DIR / config.image_name / "kvstore.db"
+                ).as_posix()
+            )
+        )
+
+    return CachedDiskDistributionRegistry(dist_kvstore), dist_kvstore


@@ -12,6 +12,7 @@ from llama_models.schema_utils import json_schema_type
 from pydantic import BaseModel, Field

 from llama_stack.apis.datasets import DatasetDef
+from llama_stack.apis.eval_tasks import EvalTaskDef
 from llama_stack.apis.memory_banks import MemoryBankDef
 from llama_stack.apis.models import ModelDef
 from llama_stack.apis.scoring_functions import ScoringFnDef
@@ -35,6 +36,7 @@ class Api(Enum):
     memory_banks = "memory_banks"
     datasets = "datasets"
     scoring_functions = "scoring_functions"
+    eval_tasks = "eval_tasks"

     # built-in API
     inspect = "inspect"
@@ -70,6 +72,12 @@ class ScoringFunctionsProtocolPrivate(Protocol):
     async def register_scoring_function(self, function_def: ScoringFnDef) -> None: ...


+class EvalTasksProtocolPrivate(Protocol):
+    async def list_eval_tasks(self) -> List[EvalTaskDef]: ...
+
+    async def register_eval_task(self, eval_task_def: EvalTaskDef) -> None: ...
+
+
 @json_schema_type
 class ProviderSpec(BaseModel):
     api: Api
@@ -82,6 +90,10 @@ class ProviderSpec(BaseModel):
         default_factory=list,
         description="Higher-level API surfaces may depend on other providers to provide their functionality",
     )
+    deprecation_warning: Optional[str] = Field(
+        default=None,
+        description="If this provider is deprecated, specify the warning message here",
+    )

     # used internally by the resolver; this is a hack for now
     deps__: List[str] = Field(default_factory=list)

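For the new `deprecation_warning` field, a hedged sketch of how a provider registry entry might set it; the resolver change earlier in this commit prints the message in bold red when such a provider is loaded. `InlineProviderSpec` is assumed from the registry hunks later in the diff, and every value below is hypothetical.

```python
# Sketch: an inline provider spec flagged as deprecated. All names are illustrative.
from llama_stack.providers.datatypes import Api, InlineProviderSpec

legacy_inference = InlineProviderSpec(
    api=Api.inference,
    provider_type="my-legacy-provider",  # hypothetical provider type
    pip_packages=[],
    module="llama_stack.providers.inline.inference.my_legacy",  # hypothetical module
    config_class="llama_stack.providers.inline.inference.my_legacy.MyLegacyConfig",
    deprecation_warning="use provider_type `my-new-provider` instead",
)
```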

@@ -4,10 +4,9 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pydantic import BaseModel, Field
from llama_stack.providers.utils.kvstore import KVStoreConfig
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
from pydantic import BaseModel, Field

class MetaReferenceAgentsImplConfig(BaseModel):


@@ -11,9 +11,8 @@ from datetime import datetime
from typing import List, Optional

from llama_stack.apis.agents import * # noqa: F403
from pydantic import BaseModel
from llama_stack.providers.utils.kvstore import KVStore
from pydantic import BaseModel

class AgentSessionInfo(BaseModel):


@@ -10,14 +10,13 @@ from jinja2 import Template
from llama_models.llama3.api import * # noqa: F403
from termcolor import cprint # noqa: F401
from llama_stack.apis.agents import (
    DefaultMemoryQueryGeneratorConfig,
    LLMMemoryQueryGeneratorConfig,
    MemoryQueryGenerator,
    MemoryQueryGeneratorConfig,
)
from termcolor import cprint # noqa: F401

from llama_stack.apis.inference import * # noqa: F403


@@ -9,8 +9,7 @@ from typing import List

 from llama_stack.apis.inference import Message
 from llama_stack.apis.safety import * # noqa: F403

-from llama_stack.providers.inline.meta_reference.agents.safety import ShieldRunnerMixin
+from ..safety import ShieldRunnerMixin

 from .builtin import BaseTool


@@ -10,9 +10,8 @@ from llama_models.datatypes import * # noqa: F403
from llama_models.sku_list import resolve_model
from llama_stack.apis.inference import * # noqa: F401, F403
from pydantic import BaseModel, Field, field_validator
from llama_stack.providers.utils.inference import supported_inference_models
from pydantic import BaseModel, Field, field_validator

class MetaReferenceInferenceConfig(BaseModel):


@@ -35,13 +35,12 @@ from termcolor import cprint
from llama_stack.apis.inference import * # noqa: F403
from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerTokenizerData
from llama_stack.distribution.utils.model_utils import model_local_dir
from llama_stack.providers.utils.inference.prompt_adapter import (
    augment_content_with_response_format_prompt,
    chat_completion_request_to_messages,
)
from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerTokenizerData

from .config import (
    Fp8QuantizationConfig,


@@ -28,13 +28,13 @@ from fairscale.nn.model_parallel.initialize import (
     get_model_parallel_src_rank,
 )

+from llama_stack.apis.inference import ChatCompletionRequest, CompletionRequest
 from pydantic import BaseModel, Field
 from torch.distributed.launcher.api import elastic_launch, LaunchConfig
 from typing_extensions import Annotated
-from llama_stack.apis.inference import ChatCompletionRequest, CompletionRequest

 from .generation import TokenResult


@@ -20,16 +20,15 @@ from llama_models.datatypes import CheckpointQuantizationFormat
 from llama_models.llama3.api.args import ModelArgs
 from llama_models.llama3.reference_impl.model import Transformer, TransformerBlock
 from llama_models.sku_list import resolve_model
+from llama_stack.apis.inference import QuantizationType

 from termcolor import cprint
 from torch import nn, Tensor
 from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear

-from llama_stack.apis.inference import QuantizationType
-from llama_stack.providers.inline.meta_reference.inference.config import (
-    MetaReferenceQuantizedInferenceConfig,
-)
+from ..config import MetaReferenceQuantizedInferenceConfig


 def swiglu_wrapper(


@@ -5,9 +5,9 @@
# the root directory of this source tree.

from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field, field_validator
from llama_stack.providers.utils.inference import supported_inference_models
from pydantic import BaseModel, Field, field_validator

@json_schema_type


@@ -5,13 +5,13 @@
# the root directory of this source tree.

from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel
from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR
from llama_stack.providers.utils.kvstore.config import (
    KVStoreConfig,
    SqliteKVStoreConfig,
)
from pydantic import BaseModel

@json_schema_type


@@ -8,10 +8,11 @@ import logging
 from typing import Any, Dict, List, Optional

+import faiss
 import numpy as np
 from numpy.typing import NDArray
-import faiss

 from llama_models.llama3.api.datatypes import * # noqa: F403
 from llama_stack.apis.memory import * # noqa: F403


@@ -6,13 +6,15 @@
 from enum import Enum

 from llama_models.llama3.api.datatypes import * # noqa: F403

+from .....apis.common.job_types import Job
+from .....apis.eval.eval import Eval, EvalTaskConfig, EvaluateResponse, JobStatus
 from llama_stack.apis.common.type_system import * # noqa: F403
-from llama_stack.apis.common.job_types import Job
 from llama_stack.apis.datasetio import DatasetIO
 from llama_stack.apis.datasets import Datasets
-from llama_stack.apis.eval import Eval, EvalCandidate, EvaluateResponse, JobStatus
+from llama_stack.apis.eval_tasks import EvalTaskDef
 from llama_stack.apis.inference import Inference
 from llama_stack.apis.scoring import Scoring
+from llama_stack.providers.datatypes import EvalTasksProtocolPrivate

 from .config import MetaReferenceEvalConfig
@@ -25,7 +27,7 @@ class ColumnName(Enum):
     generated_answer = "generated_answer"

-class MetaReferenceEvalImpl(Eval):
+class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
     def __init__(
         self,
         config: MetaReferenceEvalConfig,
@@ -43,10 +45,18 @@ class MetaReferenceEvalImpl(Eval):
         # TODO: assume sync job, will need jobs API for async scheduling
         self.jobs = {}

+        self.eval_tasks = {}
+
     async def initialize(self) -> None: ...

     async def shutdown(self) -> None: ...

+    async def register_eval_task(self, task_def: EvalTaskDef) -> None:
+        self.eval_tasks[task_def.identifier] = task_def
+
+    async def list_eval_tasks(self) -> List[EvalTaskDef]:
+        return list(self.eval_tasks.values())
+
     async def validate_eval_input_dataset_schema(self, dataset_id: str) -> None:
         dataset_def = await self.datasets_api.get_dataset(dataset_identifier=dataset_id)
         if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
@@ -70,21 +80,26 @@ class MetaReferenceEvalImpl(Eval):
                 f"Dataset {dataset_id} does not have a correct input schema in {expected_schemas}"
             )

-    async def evaluate_batch(
+    async def run_eval(
         self,
-        dataset_id: str,
-        candidate: EvalCandidate,
-        scoring_functions: List[str],
+        task_id: str,
+        task_config: EvalTaskConfig,
     ) -> Job:
+        task_def = self.eval_tasks[task_id]
+        dataset_id = task_def.dataset_id
+        candidate = task_config.eval_candidate
+        scoring_functions = task_def.scoring_functions
         await self.validate_eval_input_dataset_schema(dataset_id=dataset_id)
         all_rows = await self.datasetio_api.get_rows_paginated(
             dataset_id=dataset_id,
             rows_in_page=-1,
         )
-        res = await self.evaluate(
+        res = await self.evaluate_rows(
+            task_id=task_id,
             input_rows=all_rows.rows,
-            candidate=candidate,
             scoring_functions=scoring_functions,
+            task_config=task_config,
         )

         # TODO: currently needs to wait for generation before returning
@@ -93,12 +108,14 @@ class MetaReferenceEvalImpl(Eval):
         self.jobs[job_id] = res
         return Job(job_id=job_id)

-    async def evaluate(
+    async def evaluate_rows(
         self,
+        task_id: str,
         input_rows: List[Dict[str, Any]],
-        candidate: EvalCandidate,
         scoring_functions: List[str],
+        task_config: EvalTaskConfig,
     ) -> EvaluateResponse:
+        candidate = task_config.eval_candidate
         if candidate.type == "agent":
             raise NotImplementedError(
                 "Evaluation with generation has not been implemented for agents"
@@ -122,7 +139,10 @@ class MetaReferenceEvalImpl(Eval):
                     }
                 )
             elif ColumnName.chat_completion_input.value in x:
-                input_messages = eval(str(x[ColumnName.chat_completion_input.value]))
+                chat_completion_input_str = str(
+                    x[ColumnName.chat_completion_input.value]
+                )
+                input_messages = eval(chat_completion_input_str)
                 input_messages = [UserMessage(**x) for x in input_messages]
                 messages = []
                 if candidate.system_message:
@@ -147,23 +167,33 @@ class MetaReferenceEvalImpl(Eval):
             for input_r, generated_r in zip(input_rows, generations)
         ]

+        if task_config.type == "app" and task_config.scoring_params is not None:
+            scoring_functions_dict = {
+                scoring_fn_id: task_config.scoring_params.get(scoring_fn_id, None)
+                for scoring_fn_id in scoring_functions
+            }
+        else:
+            scoring_functions_dict = {
+                scoring_fn_id: None for scoring_fn_id in scoring_functions
+            }
+
         score_response = await self.scoring_api.score(
-            input_rows=score_input_rows, scoring_functions=scoring_functions
+            input_rows=score_input_rows, scoring_functions=scoring_functions_dict
         )

         return EvaluateResponse(generations=generations, scores=score_response.results)

-    async def job_status(self, job_id: str) -> Optional[JobStatus]:
+    async def job_status(self, task_id: str, job_id: str) -> Optional[JobStatus]:
         if job_id in self.jobs:
             return JobStatus.completed
         return None

-    async def job_cancel(self, job_id: str) -> None:
+    async def job_cancel(self, task_id: str, job_id: str) -> None:
         raise NotImplementedError("Job cancel is not implemented yet")

-    async def job_result(self, job_id: str) -> EvaluateResponse:
-        status = await self.job_status(job_id)
+    async def job_result(self, task_id: str, job_id: str) -> EvaluateResponse:
+        status = await self.job_status(task_id, job_id)
         if not status or status != JobStatus.completed:
             raise ValueError(f"Job is not completed, Status: {status.value}")


@@ -1,73 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-import tempfile
-
-import pytest
-
-from llama_stack.apis.memory import MemoryBankType, VectorMemoryBankDef
-from llama_stack.providers.inline.meta_reference.memory.config import FaissImplConfig
-from llama_stack.providers.inline.meta_reference.memory.faiss import FaissMemoryImpl
-from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
-
-
-class TestFaissMemoryImpl:
-    @pytest.fixture
-    def faiss_impl(self):
-        # Create a temporary SQLite database file
-        temp_db = tempfile.NamedTemporaryFile(suffix=".db", delete=False)
-        config = FaissImplConfig(kvstore=SqliteKVStoreConfig(db_path=temp_db.name))
-        return FaissMemoryImpl(config)
-
-    @pytest.mark.asyncio
-    async def test_initialize(self, faiss_impl):
-        # Test empty initialization
-        await faiss_impl.initialize()
-        assert len(faiss_impl.cache) == 0
-
-        # Test initialization with existing banks
-        bank = VectorMemoryBankDef(
-            identifier="test_bank",
-            type=MemoryBankType.vector.value,
-            embedding_model="all-MiniLM-L6-v2",
-            chunk_size_in_tokens=512,
-            overlap_size_in_tokens=64,
-        )
-
-        # Register a bank and reinitialize to test loading
-        await faiss_impl.register_memory_bank(bank)
-
-        # Create new instance to test initialization with existing data
-        new_impl = FaissMemoryImpl(faiss_impl.config)
-        await new_impl.initialize()
-
-        assert len(new_impl.cache) == 1
-        assert "test_bank" in new_impl.cache
-
-    @pytest.mark.asyncio
-    async def test_register_memory_bank(self, faiss_impl):
-        bank = VectorMemoryBankDef(
-            identifier="test_bank",
-            type=MemoryBankType.vector.value,
-            embedding_model="all-MiniLM-L6-v2",
-            chunk_size_in_tokens=512,
-            overlap_size_in_tokens=64,
-        )
-
-        await faiss_impl.initialize()
-        await faiss_impl.register_memory_bank(bank)
-
-        assert "test_bank" in faiss_impl.cache
-        assert faiss_impl.cache["test_bank"].bank == bank
-
-        # Verify persistence
-        new_impl = FaissMemoryImpl(faiss_impl.config)
-        await new_impl.initialize()
-
-        assert "test_bank" in new_impl.cache
-
-
-if __name__ == "__main__":
-    pytest.main([__file__])


@@ -74,8 +74,7 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
         return scoring_fn_defs_list

     async def register_scoring_function(self, function_def: ScoringFnDef) -> None:
-        self.llm_as_judge_fn.register_scoring_fn_def(function_def)
-        self.scoring_fn_id_impls[function_def.identifier] = self.llm_as_judge_fn
+        raise NotImplementedError("Register scoring function not implemented yet")

     async def validate_scoring_input_dataset_schema(self, dataset_id: str) -> None:
         dataset_def = await self.datasets_api.get_dataset(dataset_identifier=dataset_id)
@@ -97,7 +96,7 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
     async def score_batch(
         self,
         dataset_id: str,
-        scoring_functions: List[str],
+        scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
         save_results_dataset: bool = False,
     ) -> ScoreBatchResponse:
         await self.validate_scoring_input_dataset_schema(dataset_id=dataset_id)
@@ -106,7 +105,8 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
             rows_in_page=-1,
         )
         res = await self.score(
-            input_rows=all_rows.rows, scoring_functions=scoring_functions
+            input_rows=all_rows.rows,
+            scoring_functions=scoring_functions,
         )
         if save_results_dataset:
             # TODO: persist and register dataset on to server for reading
@@ -118,14 +118,19 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
         )

     async def score(
-        self, input_rows: List[Dict[str, Any]], scoring_functions: List[str]
+        self,
+        input_rows: List[Dict[str, Any]],
+        scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
     ) -> ScoreResponse:
         res = {}
-        for scoring_fn_id in scoring_functions:
+        for scoring_fn_id in scoring_functions.keys():
             if scoring_fn_id not in self.scoring_fn_id_impls:
                 raise ValueError(f"Scoring function {scoring_fn_id} is not supported.")
             scoring_fn = self.scoring_fn_id_impls[scoring_fn_id]
-            score_results = await scoring_fn.score(input_rows, scoring_fn_id)
+            scoring_fn_params = scoring_functions.get(scoring_fn_id, None)
+            score_results = await scoring_fn.score(
+                input_rows, scoring_fn_id, scoring_fn_params
+            )
             agg_results = await scoring_fn.aggregate(score_results)
             res[scoring_fn_id] = ScoringResult(
                 score_rows=score_results,


@@ -36,7 +36,10 @@ class BaseScoringFn(ABC):
     @abstractmethod
     async def score_row(
-        self, input_row: Dict[str, Any], scoring_fn_identifier: Optional[str] = None
+        self,
+        input_row: Dict[str, Any],
+        scoring_fn_identifier: Optional[str] = None,
+        scoring_params: Optional[ScoringFnParams] = None,
     ) -> ScoringResultRow:
         raise NotImplementedError()
@@ -50,8 +53,9 @@ class BaseScoringFn(ABC):
         self,
         input_rows: List[Dict[str, Any]],
         scoring_fn_identifier: Optional[str] = None,
+        scoring_params: Optional[ScoringFnParams] = None,
     ) -> List[ScoringResultRow]:
         return [
-            await self.score_row(input_row, scoring_fn_identifier)
+            await self.score_row(input_row, scoring_fn_identifier, scoring_params)
             for input_row in input_rows
         ]

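To illustrate the extended `score_row` signature, here is a hedged sketch of a scoring function that honours an optional `RegexParserScoringFnParams` override. It deliberately does not subclass `BaseScoringFn` (whose module path is not shown in this hunk) so it stays self-contained; the return dict mirrors the `ScoringResultRow` usage above.

```python
# Sketch: a regex-based scorer accepting the new optional scoring_params argument.
import re
from typing import Any, Dict, Optional

from llama_stack.apis.scoring_functions import (
    RegexParserScoringFnParams,
    ScoringFnParams,
)


class RegexMatchScoringFn:
    """In a real provider this would subclass BaseScoringFn."""

    async def score_row(
        self,
        input_row: Dict[str, Any],
        scoring_fn_identifier: Optional[str] = "regex_match",
        scoring_params: Optional[ScoringFnParams] = None,
    ) -> Dict[str, Any]:
        # Per-call params override the default pattern when provided.
        patterns = [r".+"]
        if isinstance(scoring_params, RegexParserScoringFnParams) and scoring_params.parsing_regexes:
            patterns = scoring_params.parsing_regexes

        generated = input_row["generated_answer"]
        matched = any(re.search(p, generated) for p in patterns)
        return {"score": 1.0 if matched else 0.0}
```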

@@ -35,6 +35,7 @@ class EqualityScoringFn(BaseScoringFn):
         self,
         input_row: Dict[str, Any],
         scoring_fn_identifier: Optional[str] = "equality",
+        scoring_params: Optional[ScoringFnParams] = None,
     ) -> ScoringResultRow:
         assert "expected_answer" in input_row, "Expected answer not found in input row."
         assert (


@@ -28,9 +28,13 @@ llm_as_judge_8b_correctness = ScoringFnDef(
     description="Llm As Judge Scoring Function",
     parameters=[],
     return_type=NumberType(),
-    context=LLMAsJudgeContext(
+    params=LLMAsJudgeScoringFnParams(
         prompt_template=JUDGE_PROMPT,
         judge_model="Llama3.1-8B-Instruct",
-        judge_score_regex=[r"Total rating: (\d+)", r"rating: (\d+)", r"Rating: (\d+)"],
+        judge_score_regexes=[
+            r"Total rating: (\d+)",
+            r"rating: (\d+)",
+            r"Rating: (\d+)",
+        ],
     ),
 )

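By way of contrast with the LLM-as-judge definition above, a scoring function definition can now carry `RegexParserScoringFnParams` instead. This is a hedged sketch only; the identifier, description, and regex are made up, and `NumberType` is assumed to live in `llama_stack.apis.common.type_system` as in the surrounding code.

```python
# Sketch: a ScoringFnDef whose default params use the new regex parser variant.
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.scoring_functions import RegexParserScoringFnParams, ScoringFnDef

regex_answer_match = ScoringFnDef(
    identifier="regex_answer_match",  # hypothetical identifier
    description="Extracts the final answer with a regex before comparing it",
    return_type=NumberType(),
    params=RegexParserScoringFnParams(
        parsing_regexes=[r"Answer:\s*(.*)"],
    ),
)
```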

@@ -36,31 +36,37 @@ class LlmAsJudgeScoringFn(BaseScoringFn):
         self,
         input_row: Dict[str, Any],
         scoring_fn_identifier: Optional[str] = None,
+        scoring_params: Optional[ScoringFnParams] = None,
     ) -> ScoringResultRow:
         assert (
             scoring_fn_identifier is not None
         ), "Scoring function identifier not found."
         fn_def = self.supported_fn_defs_registry[scoring_fn_identifier]
-        assert fn_def.context is not None, f"LLMAsJudgeContext not found for {fn_def}."
+
+        # override params if scoring_params is provided
+        if scoring_params is not None:
+            fn_def.params = scoring_params
+
+        assert fn_def.params is not None, f"LLMAsJudgeparams not found for {fn_def}."
         assert (
-            fn_def.context.prompt_template is not None
+            fn_def.params.prompt_template is not None
         ), "LLM Judge prompt_template not found."
         assert (
-            fn_def.context.judge_score_regex is not None
-        ), "LLM Judge judge_score_regex not found."
+            fn_def.params.judge_score_regexes is not None
+        ), "LLM Judge judge_score_regexes not found."

         input_query = input_row["input_query"]
         expected_answer = input_row["expected_answer"]
         generated_answer = input_row["generated_answer"]

-        judge_input_msg = fn_def.context.prompt_template.format(
+        judge_input_msg = fn_def.params.prompt_template.format(
             input_query=input_query,
             expected_answer=expected_answer,
             generated_answer=generated_answer,
         )

         judge_response = await self.inference_api.chat_completion(
-            model=fn_def.context.judge_model,
+            model=fn_def.params.judge_model,
             messages=[
                 {
                     "role": "user",
@@ -69,10 +75,10 @@ class LlmAsJudgeScoringFn(BaseScoringFn):
             ],
         )
         content = judge_response.completion_message.content
-        rating_regexs = fn_def.context.judge_score_regex
+        rating_regexes = fn_def.params.judge_score_regexes

         judge_rating = None
-        for regex in rating_regexs:
+        for regex in rating_regexes:
             match = re.search(regex, content)
             if match:
                 judge_rating = int(match.group(1))

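Because `score_row` now accepts `scoring_params`, callers can override the registered judge configuration per request. A hedged sketch of such a call; the import path and the `scoring_impl` handle are assumptions, while the parameter values mirror the scoring tests later in this commit:

```python
# Assumed import path for the params type introduced by this change.
from llama_stack.apis.scoring_functions import LLMAsJudgeScoringFnParams


async def score_with_custom_judge(scoring_impl, rows):
    # `scoring_impl` is a resolved Scoring API implementation; `rows` are dataset rows
    # with input_query / expected_answer / generated_answer fields.
    return await scoring_impl.score(
        input_rows=rows,
        scoring_functions={
            "meta-reference::llm_as_judge_8b_correctness": LLMAsJudgeScoringFnParams(
                judge_model="Llama3.1-405B-Instruct",
                prompt_template="Rate the answer. Reply exactly as: Score: <number 0-9>.",
                judge_score_regexes=[r"Score: (\d+)"],
            )
        },
    )
```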
View file

@@ -34,6 +34,7 @@ class SubsetOfScoringFn(BaseScoringFn):
         self,
         input_row: Dict[str, Any],
         scoring_fn_identifier: Optional[str] = "subset_of",
+        scoring_params: Optional[ScoringFnParams] = None,
     ) -> ScoringResultRow:
         expected_answer = input_row["expected_answer"]
         generated_answer = input_row["generated_answer"]

View file

@@ -22,8 +22,8 @@ def available_providers() -> List[ProviderSpec]:
             "scikit-learn",
         ]
         + kvstore_dependencies(),
-        module="llama_stack.providers.inline.meta_reference.agents",
-        config_class="llama_stack.providers.inline.meta_reference.agents.MetaReferenceAgentsImplConfig",
+        module="llama_stack.providers.inline.agents.meta_reference",
+        config_class="llama_stack.providers.inline.agents.meta_reference.MetaReferenceAgentsImplConfig",
         api_dependencies=[
             Api.inference,
             Api.safety,

View file

@@ -27,8 +27,8 @@ def available_providers() -> List[ProviderSpec]:
         api=Api.inference,
         provider_type="meta-reference",
         pip_packages=META_REFERENCE_DEPS,
-        module="llama_stack.providers.inline.meta_reference.inference",
-        config_class="llama_stack.providers.inline.meta_reference.inference.MetaReferenceInferenceConfig",
+        module="llama_stack.providers.inline.inference.meta_reference",
+        config_class="llama_stack.providers.inline.inference.meta_reference.MetaReferenceInferenceConfig",
     ),
     InlineProviderSpec(
         api=Api.inference,

@@ -40,8 +40,17 @@ def available_providers() -> List[ProviderSpec]:
                 "torchao==0.5.0",
             ]
         ),
-        module="llama_stack.providers.inline.meta_reference.inference",
-        config_class="llama_stack.providers.inline.meta_reference.inference.MetaReferenceQuantizedInferenceConfig",
+        module="llama_stack.providers.inline.inference.meta_reference",
+        config_class="llama_stack.providers.inline.inference.meta_reference.MetaReferenceQuantizedInferenceConfig",
+    ),
+    InlineProviderSpec(
+        api=Api.inference,
+        provider_type="vllm",
+        pip_packages=[
+            "vllm",
+        ],
+        module="llama_stack.providers.inline.inference.vllm",
+        config_class="llama_stack.providers.inline.inference.vllm.VLLMConfig",
     ),
     remote_provider_spec(
         api=Api.inference,

@@ -117,7 +126,7 @@ def available_providers() -> List[ProviderSpec]:
             ],
             module="llama_stack.providers.remote.inference.together",
             config_class="llama_stack.providers.remote.inference.together.TogetherImplConfig",
-            provider_data_validator="llama_stack.providers.remote.safety.together.TogetherProviderDataValidator",
+            provider_data_validator="llama_stack.providers.remote.inference.together.TogetherProviderDataValidator",
         ),
     ),
     remote_provider_spec(

@@ -149,13 +158,4 @@ def available_providers() -> List[ProviderSpec]:
             config_class="llama_stack.providers.adapters.inference.azure_ai_inference.AzureAIInferenceConfig",
         ),
     ),
-    InlineProviderSpec(
-        api=Api.inference,
-        provider_type="vllm",
-        pip_packages=[
-            "vllm",
-        ],
-        module="llama_stack.providers.inline.vllm",
-        config_class="llama_stack.providers.inline.vllm.VLLMConfig",
-    ),
 ]

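The registry updates above track the inline-provider package rename from `llama_stack.providers.inline.meta_reference.<api>` to `llama_stack.providers.inline.<api>.<provider>`, so downstream imports change accordingly (the new path below is the one used by the updated test fixtures later in this commit):

```python
# Before this commit:
# from llama_stack.providers.inline.meta_reference.inference import MetaReferenceInferenceConfig

# After this commit:
from llama_stack.providers.inline.inference.meta_reference import (
    MetaReferenceInferenceConfig,
)
```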
View file

@@ -36,8 +36,16 @@ def available_providers() -> List[ProviderSpec]:
         api=Api.memory,
         provider_type="meta-reference",
         pip_packages=EMBEDDING_DEPS + ["faiss-cpu"],
-        module="llama_stack.providers.inline.meta_reference.memory",
-        config_class="llama_stack.providers.inline.meta_reference.memory.FaissImplConfig",
+        module="llama_stack.providers.inline.memory.faiss",
+        config_class="llama_stack.providers.inline.memory.faiss.FaissImplConfig",
+        deprecation_warning="Please use the `faiss` provider instead.",
+    ),
+    InlineProviderSpec(
+        api=Api.memory,
+        provider_type="faiss",
+        pip_packages=EMBEDDING_DEPS + ["faiss-cpu"],
+        module="llama_stack.providers.inline.memory.faiss",
+        config_class="llama_stack.providers.inline.memory.faiss.FaissImplConfig",
     ),
     remote_provider_spec(
         Api.memory,

View file

@@ -24,8 +24,8 @@ def available_providers() -> List[ProviderSpec]:
             "transformers",
             "torch --index-url https://download.pytorch.org/whl/cpu",
         ],
-        module="llama_stack.providers.inline.meta_reference.safety",
-        config_class="llama_stack.providers.inline.meta_reference.safety.SafetyConfig",
+        module="llama_stack.providers.inline.safety.meta_reference",
+        config_class="llama_stack.providers.inline.safety.meta_reference.SafetyConfig",
         api_dependencies=[
             Api.inference,
         ],

@@ -54,8 +54,8 @@ def available_providers() -> List[ProviderSpec]:
         pip_packages=[
             "codeshield",
         ],
-        module="llama_stack.providers.inline.meta_reference.codeshield",
-        config_class="llama_stack.providers.inline.meta_reference.codeshield.CodeShieldConfig",
+        module="llama_stack.providers.inline.safety.meta_reference",
+        config_class="llama_stack.providers.inline.safety.meta_reference.CodeShieldConfig",
         api_dependencies=[],
     ),
 ]

View file

@@ -4,6 +4,8 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+from typing import Optional
+
 from llama_models.schema_utils import json_schema_type
 from pydantic import BaseModel, Field

@@ -14,7 +16,7 @@ class FireworksImplConfig(BaseModel):
         default="https://api.fireworks.ai/inference",
         description="The URL for the Fireworks server",
     )
-    api_key: str = Field(
-        default="",
+    api_key: Optional[str] = Field(
+        default=None,
         description="The Fireworks.ai API Key",
     )

View file

@@ -9,12 +9,11 @@ from typing import AsyncGenerator
 from fireworks.client import Fireworks
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.datatypes import Message
 from llama_models.llama3.api.tokenizer import Tokenizer

 from llama_stack.apis.inference import *  # noqa: F403
+from llama_stack.distribution.request_headers import NeedsRequestProviderData
 from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
 from llama_stack.providers.utils.inference.openai_compat import (
     get_sampling_options,

@@ -32,7 +31,6 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
 from .config import FireworksImplConfig

 FIREWORKS_SUPPORTED_MODELS = {
     "Llama3.1-8B-Instruct": "fireworks/llama-v3p1-8b-instruct",
     "Llama3.1-70B-Instruct": "fireworks/llama-v3p1-70b-instruct",

@@ -41,10 +39,13 @@ FIREWORKS_SUPPORTED_MODELS = {
     "Llama3.2-3B-Instruct": "fireworks/llama-v3p2-3b-instruct",
     "Llama3.2-11B-Vision-Instruct": "fireworks/llama-v3p2-11b-vision-instruct",
     "Llama3.2-90B-Vision-Instruct": "fireworks/llama-v3p2-90b-vision-instruct",
+    "Llama-Guard-3-8B": "fireworks/llama-guard-3-8b",
 }


-class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
+class FireworksInferenceAdapter(
+    ModelRegistryHelper, Inference, NeedsRequestProviderData
+):
     def __init__(self, config: FireworksImplConfig) -> None:
         ModelRegistryHelper.__init__(
             self, stack_to_provider_models_map=FIREWORKS_SUPPORTED_MODELS

@@ -53,11 +54,24 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
         self.formatter = ChatFormat(Tokenizer.get_instance())

     async def initialize(self) -> None:
-        return
+        pass

     async def shutdown(self) -> None:
         pass

+    def _get_client(self) -> Fireworks:
+        fireworks_api_key = None
+        if self.config.api_key is not None:
+            fireworks_api_key = self.config.api_key
+        else:
+            provider_data = self.get_request_provider_data()
+            if provider_data is None or not provider_data.fireworks_api_key:
+                raise ValueError(
+                    'Pass Fireworks API Key in the header X-LlamaStack-ProviderData as { "fireworks_api_key": <your api key>}'
+                )
+            fireworks_api_key = provider_data.fireworks_api_key
+        return Fireworks(api_key=fireworks_api_key)
+
     async def completion(
         self,
         model: str,

@@ -75,28 +89,53 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
             stream=stream,
             logprobs=logprobs,
         )
-        client = Fireworks(api_key=self.config.api_key)
         if stream:
-            return self._stream_completion(request, client)
+            return self._stream_completion(request)
         else:
-            return await self._nonstream_completion(request, client)
+            return await self._nonstream_completion(request)

     async def _nonstream_completion(
-        self, request: CompletionRequest, client: Fireworks
+        self, request: CompletionRequest
     ) -> CompletionResponse:
         params = await self._get_params(request)
-        r = await client.completion.acreate(**params)
+        r = await self._get_client().completion.acreate(**params)
         return process_completion_response(r, self.formatter)

-    async def _stream_completion(
-        self, request: CompletionRequest, client: Fireworks
-    ) -> AsyncGenerator:
+    async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
         params = await self._get_params(request)

-        stream = client.completion.acreate(**params)
+        # Wrapper for async generator similar
+        async def _to_async_generator():
+            stream = self._get_client().completion.create(**params)
+            for chunk in stream:
+                yield chunk
+
+        stream = _to_async_generator()
         async for chunk in process_completion_stream_response(stream, self.formatter):
             yield chunk

+    def _build_options(
+        self, sampling_params: Optional[SamplingParams], fmt: ResponseFormat
+    ) -> dict:
+        options = get_sampling_options(sampling_params)
+        options.setdefault("max_tokens", 512)
+
+        if fmt:
+            if fmt.type == ResponseFormatType.json_schema.value:
+                options["response_format"] = {
+                    "type": "json_object",
+                    "schema": fmt.json_schema,
+                }
+            elif fmt.type == ResponseFormatType.grammar.value:
+                options["response_format"] = {
+                    "type": "grammar",
+                    "grammar": fmt.bnf,
+                }
+            else:
+                raise ValueError(f"Unknown response format {fmt.type}")
+
+        return options
+
     async def chat_completion(
         self,
         model: str,

@@ -121,32 +160,35 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
             logprobs=logprobs,
         )
-        client = Fireworks(api_key=self.config.api_key)
         if stream:
-            return self._stream_chat_completion(request, client)
+            return self._stream_chat_completion(request)
         else:
-            return await self._nonstream_chat_completion(request, client)
+            return await self._nonstream_chat_completion(request)

     async def _nonstream_chat_completion(
-        self, request: ChatCompletionRequest, client: Fireworks
+        self, request: ChatCompletionRequest
     ) -> ChatCompletionResponse:
         params = await self._get_params(request)
         if "messages" in params:
-            r = await client.chat.completions.acreate(**params)
+            r = await self._get_client().chat.completions.acreate(**params)
         else:
-            r = await client.completion.acreate(**params)
+            r = await self._get_client().completion.acreate(**params)
         return process_chat_completion_response(r, self.formatter)

     async def _stream_chat_completion(
-        self, request: ChatCompletionRequest, client: Fireworks
+        self, request: ChatCompletionRequest
     ) -> AsyncGenerator:
         params = await self._get_params(request)

-        if "messages" in params:
-            stream = client.chat.completions.acreate(**params)
-        else:
-            stream = client.completion.acreate(**params)
+        async def _to_async_generator():
+            if "messages" in params:
+                stream = await self._get_client().chat.completions.acreate(**params)
+            else:
+                stream = self._get_client().completion.create(**params)
+            for chunk in stream:
+                yield chunk
+
+        stream = _to_async_generator()
         async for chunk in process_chat_completion_stream_response(
             stream, self.formatter
         ):

@@ -167,41 +209,22 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
             input_dict["prompt"] = chat_completion_request_to_prompt(
                 request, self.formatter
             )
-        elif isinstance(request, CompletionRequest):
+        else:
             assert (
                 not media_present
             ), "Fireworks does not support media for Completion requests"
             input_dict["prompt"] = completion_request_to_prompt(request, self.formatter)
-        else:
-            raise ValueError(f"Unknown request type {type(request)}")

         # Fireworks always prepends with BOS
         if "prompt" in input_dict:
             if input_dict["prompt"].startswith("<|begin_of_text|>"):
                 input_dict["prompt"] = input_dict["prompt"][len("<|begin_of_text|>") :]

-        options = get_sampling_options(request.sampling_params)
-        options.setdefault("max_tokens", 512)
-        if fmt := request.response_format:
-            if fmt.type == ResponseFormatType.json_schema.value:
-                options["response_format"] = {
-                    "type": "json_object",
-                    "schema": fmt.json_schema,
-                }
-            elif fmt.type == ResponseFormatType.grammar.value:
-                options["response_format"] = {
-                    "type": "grammar",
-                    "grammar": fmt.bnf,
-                }
-            else:
-                raise ValueError(f"Unknown response format {fmt.type}")
-
         return {
             "model": self.map_to_provider_model(request.model),
             **input_dict,
             "stream": request.stream,
-            **options,
+            **self._build_options(request.sampling_params, request.response_format),
         }

     async def embeddings(

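Since `api_key` is now optional in the config, the adapter can instead take the key per request via the `X-LlamaStack-ProviderData` header, as the error message in `_get_client` spells out. A hedged client-side sketch; only the header name and the `fireworks_api_key` field come from the code above, while the HTTP client, endpoint URL, port, and request body are placeholders:

```python
import json

import requests  # illustrative HTTP client, not part of the diff

headers = {
    "Content-Type": "application/json",
    "X-LlamaStack-ProviderData": json.dumps({"fireworks_api_key": "<your api key>"}),
}

# Placeholder endpoint: adjust host/port/path to your running Llama Stack server.
resp = requests.post(
    "http://localhost:5000/inference/chat_completion",
    headers=headers,
    json={
        "model": "Llama3.1-8B-Instruct",
        "messages": [{"role": "user", "content": "Hello!"}],
    },
)
print(resp.status_code)
```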
View file

@@ -4,9 +4,15 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+from pydantic import BaseModel
+
 from .config import TogetherImplConfig


+class TogetherProviderDataValidator(BaseModel):
+    together_api_key: str
+
+
 async def get_adapter_impl(config: TogetherImplConfig, _deps):
     from .together import TogetherInferenceAdapter

View file

@@ -11,7 +11,7 @@ import pytest_asyncio

 from llama_stack.distribution.datatypes import Api, Provider
-from llama_stack.providers.inline.meta_reference.agents import (
+from llama_stack.providers.inline.agents.meta_reference import (
     MetaReferenceAgentsImplConfig,
 )

View file

@@ -153,4 +153,7 @@ pytest_plugins = [
     "llama_stack.providers.tests.safety.fixtures",
     "llama_stack.providers.tests.memory.fixtures",
     "llama_stack.providers.tests.agents.fixtures",
+    "llama_stack.providers.tests.datasetio.fixtures",
+    "llama_stack.providers.tests.scoring.fixtures",
+    "llama_stack.providers.tests.eval.fixtures",
 ]

View file

@ -0,0 +1,29 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest
from .fixtures import DATASETIO_FIXTURES
def pytest_configure(config):
for fixture_name in DATASETIO_FIXTURES:
config.addinivalue_line(
"markers",
f"{fixture_name}: marks tests as {fixture_name} specific",
)
def pytest_generate_tests(metafunc):
if "datasetio_stack" in metafunc.fixturenames:
metafunc.parametrize(
"datasetio_stack",
[
pytest.param(fixture_name, marks=getattr(pytest.mark, fixture_name))
for fixture_name in DATASETIO_FIXTURES
],
indirect=True,
)

View file

@ -0,0 +1,48 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest
import pytest_asyncio
from llama_stack.distribution.datatypes import Api, Provider
from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
from ..conftest import ProviderFixture, remote_stack_fixture
@pytest.fixture(scope="session")
def datasetio_remote() -> ProviderFixture:
return remote_stack_fixture()
@pytest.fixture(scope="session")
def datasetio_meta_reference() -> ProviderFixture:
return ProviderFixture(
providers=[
Provider(
provider_id="meta-reference",
provider_type="meta-reference",
config={},
)
],
)
DATASETIO_FIXTURES = ["meta_reference", "remote"]
@pytest_asyncio.fixture(scope="session")
async def datasetio_stack(request):
fixture_name = request.param
fixture = request.getfixturevalue(f"datasetio_{fixture_name}")
impls = await resolve_impls_for_test_v2(
[Api.datasetio],
{"datasetio": fixture.providers},
fixture.provider_data,
)
return impls[Api.datasetio], impls[Api.datasets]

View file

@ -1,4 +0,0 @@
providers:
- provider_id: test-meta
provider_type: meta-reference
config: {}

View file

@@ -3,11 +3,10 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 import os

 import pytest
-import pytest_asyncio

 from llama_stack.apis.common.type_system import *  # noqa: F403
 from llama_stack.apis.datasetio import *  # noqa: F403
 from llama_stack.distribution.datatypes import *  # noqa: F403

@@ -15,35 +14,11 @@
 import mimetypes
 from pathlib import Path

-from llama_stack.providers.tests.resolver import resolve_impls_for_test

 # How to run this test:
 #
-# 1. Ensure you have a conda with the right dependencies installed. This is a bit tricky
-#    since it depends on the provider you are testing. On top of that you need
-#    `pytest` and `pytest-asyncio` installed.
-#
-# 2. Copy and modify the provider_config_example.yaml depending on the provider you are testing.
-#
-# 3. Run:
-#
-# ```bash
-# PROVIDER_ID=<your_provider> \
-# PROVIDER_CONFIG=provider_config.yaml \
-# pytest -s llama_stack/providers/tests/datasetio/test_datasetio.py \
-# --tb=short --disable-warnings
-# ```
-
-
-@pytest_asyncio.fixture(scope="session")
-async def datasetio_settings():
-    impls = await resolve_impls_for_test(
-        Api.datasetio,
-    )
-    return {
-        "datasetio_impl": impls[Api.datasetio],
-        "datasets_impl": impls[Api.datasets],
-    }
+# pytest llama_stack/providers/tests/datasetio/test_datasetio.py
+#   -m "meta_reference"
+#   -v -s --tb=short --disable-warnings


 def data_url_from_file(file_path: str) -> str:

@@ -82,8 +57,7 @@ async def register_dataset(
     dataset = DatasetDefWithProvider(
         identifier=dataset_id,
-        provider_id=os.environ.get("DATASETIO_PROVIDER_ID", None)
-        or os.environ["PROVIDER_ID"],
+        provider_id="",
         url=URL(
             uri=test_url,
         ),

@@ -92,57 +66,47 @@
     await datasets_impl.register_dataset(dataset)


-@pytest.mark.asyncio
-async def test_datasets_list(datasetio_settings):
-    # NOTE: this needs you to ensure that you are starting from a clean state
-    # but so far we don't have an unregister API unfortunately, so be careful
-    datasets_impl = datasetio_settings["datasets_impl"]
-    response = await datasets_impl.list_datasets()
-    assert isinstance(response, list)
-    assert len(response) == 0
-
-
-@pytest.mark.asyncio
-async def test_datasets_register(datasetio_settings):
-    # NOTE: this needs you to ensure that you are starting from a clean state
-    # but so far we don't have an unregister API unfortunately, so be careful
-    datasets_impl = datasetio_settings["datasets_impl"]
-    await register_dataset(datasets_impl)
-
-    response = await datasets_impl.list_datasets()
-    assert isinstance(response, list)
-    assert len(response) == 1
-
-    # register same dataset with same id again will fail
-    await register_dataset(datasets_impl)
-    response = await datasets_impl.list_datasets()
-    assert isinstance(response, list)
-    assert len(response) == 1
-    assert response[0].identifier == "test_dataset"
-
-
-@pytest.mark.asyncio
-async def test_get_rows_paginated(datasetio_settings):
-    datasetio_impl = datasetio_settings["datasetio_impl"]
-    datasets_impl = datasetio_settings["datasets_impl"]
-    await register_dataset(datasets_impl)
-    response = await datasetio_impl.get_rows_paginated(
-        dataset_id="test_dataset",
-        rows_in_page=3,
-    )
-    assert isinstance(response.rows, list)
-    assert len(response.rows) == 3
-    assert response.next_page_token == "3"
-
-    # iterate over all rows
-    response = await datasetio_impl.get_rows_paginated(
-        dataset_id="test_dataset",
-        rows_in_page=2,
-        page_token=response.next_page_token,
-    )
-    assert isinstance(response.rows, list)
-    assert len(response.rows) == 2
-    assert response.next_page_token == "5"
+class TestDatasetIO:
+    @pytest.mark.asyncio
+    async def test_datasets_list(self, datasetio_stack):
+        # NOTE: this needs you to ensure that you are starting from a clean state
+        # but so far we don't have an unregister API unfortunately, so be careful
+        _, datasets_impl = datasetio_stack
+        response = await datasets_impl.list_datasets()
+        assert isinstance(response, list)
+        assert len(response) == 0
+
+    @pytest.mark.asyncio
+    async def test_register_dataset(self, datasetio_stack):
+        _, datasets_impl = datasetio_stack
+        await register_dataset(datasets_impl)
+        response = await datasets_impl.list_datasets()
+        assert isinstance(response, list)
+        assert len(response) == 1
+        assert response[0].identifier == "test_dataset"
+
+    @pytest.mark.asyncio
+    async def test_get_rows_paginated(self, datasetio_stack):
+        datasetio_impl, datasets_impl = datasetio_stack
+        await register_dataset(datasets_impl)
+        response = await datasetio_impl.get_rows_paginated(
+            dataset_id="test_dataset",
+            rows_in_page=3,
+        )
+        assert isinstance(response.rows, list)
+        assert len(response.rows) == 3
+        assert response.next_page_token == "3"
+
+        provider = datasetio_impl.routing_table.get_provider_impl("test_dataset")
+        if provider.__provider_spec__.provider_type == "remote":
+            pytest.skip("remote provider doesn't support get_rows_paginated")
+
+        # iterate over all rows
+        response = await datasetio_impl.get_rows_paginated(
+            dataset_id="test_dataset",
+            rows_in_page=2,
+            page_token=response.next_page_token,
+        )
+        assert isinstance(response.rows, list)
+        assert len(response.rows) == 2
+        assert response.next_page_token == "5"

View file

@ -0,0 +1,72 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest
from ..conftest import get_provider_fixture_overrides
from ..datasetio.fixtures import DATASETIO_FIXTURES
from ..inference.fixtures import INFERENCE_FIXTURES
from ..scoring.fixtures import SCORING_FIXTURES
from .fixtures import EVAL_FIXTURES
DEFAULT_PROVIDER_COMBINATIONS = [
pytest.param(
{
"eval": "meta_reference",
"scoring": "meta_reference",
"datasetio": "meta_reference",
"inference": "fireworks",
},
id="meta_reference_eval_fireworks_inference",
marks=pytest.mark.meta_reference_eval_fireworks_inference,
),
pytest.param(
{
"eval": "meta_reference",
"scoring": "meta_reference",
"datasetio": "meta_reference",
"inference": "together",
},
id="meta_reference_eval_together_inference",
marks=pytest.mark.meta_reference_eval_together_inference,
),
]
def pytest_configure(config):
for fixture_name in [
"meta_reference_eval_fireworks_inference",
"meta_reference_eval_together_inference",
]:
config.addinivalue_line(
"markers",
f"{fixture_name}: marks tests as {fixture_name} specific",
)
def pytest_addoption(parser):
parser.addoption(
"--inference-model",
action="store",
default="Llama3.2-3B-Instruct",
help="Specify the inference model to use for testing",
)
def pytest_generate_tests(metafunc):
if "eval_stack" in metafunc.fixturenames:
available_fixtures = {
"eval": EVAL_FIXTURES,
"scoring": SCORING_FIXTURES,
"datasetio": DATASETIO_FIXTURES,
"inference": INFERENCE_FIXTURES,
}
combinations = (
get_provider_fixture_overrides(metafunc.config, available_fixtures)
or DEFAULT_PROVIDER_COMBINATIONS
)
metafunc.parametrize("eval_stack", combinations, indirect=True)

View file

@ -0,0 +1,55 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest
import pytest_asyncio
from llama_stack.distribution.datatypes import Api, Provider
from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
from ..conftest import ProviderFixture, remote_stack_fixture
@pytest.fixture(scope="session")
def eval_remote() -> ProviderFixture:
return remote_stack_fixture()
@pytest.fixture(scope="session")
def eval_meta_reference() -> ProviderFixture:
return ProviderFixture(
providers=[
Provider(
provider_id="meta-reference",
provider_type="meta-reference",
config={},
)
],
)
EVAL_FIXTURES = ["meta_reference", "remote"]
@pytest_asyncio.fixture(scope="session")
async def eval_stack(request):
fixture_dict = request.param
providers = {}
provider_data = {}
for key in ["datasetio", "eval", "scoring", "inference"]:
fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}")
providers[key] = fixture.providers
if fixture.provider_data:
provider_data.update(fixture.provider_data)
impls = await resolve_impls_for_test_v2(
[Api.eval, Api.datasetio, Api.inference, Api.scoring],
providers,
provider_data,
)
return impls

View file

@ -1,22 +0,0 @@
providers:
datasetio:
- provider_id: test-meta
provider_type: meta-reference
config: {}
scoring:
- provider_id: test-meta
provider_type: meta-reference
config: {}
eval:
- provider_id: test-meta
provider_type: meta-reference
config: {}
inference:
- provider_id: test-tgi
provider_type: remote::tgi
config:
url: http://127.0.0.1:5009
- provider_id: test-tgi-2
provider_type: remote::tgi
config:
url: http://127.0.0.1:5010

View file

@@ -3,81 +3,124 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-import pytest
-import pytest_asyncio
-
-from llama_stack.apis.common.type_system import *  # noqa: F403
-from llama_stack.apis.datasetio import *  # noqa: F403
-from llama_stack.apis.eval.eval import ModelCandidate
-from llama_stack.distribution.datatypes import *  # noqa: F403
+
+import pytest
+
 from llama_models.llama3.api import SamplingParams
+
+from llama_stack.apis.eval.eval import (
+    AppEvalTaskConfig,
+    EvalTaskDefWithProvider,
+    ModelCandidate,
+)
+from llama_stack.distribution.datatypes import Api
 from llama_stack.providers.tests.datasetio.test_datasetio import register_dataset
-from llama_stack.providers.tests.resolver import resolve_impls_for_test

 # How to run this test:
 #
-# 1. Ensure you have a conda with the right dependencies installed. This is a bit tricky
-#    since it depends on the provider you are testing. On top of that you need
-#    `pytest` and `pytest-asyncio` installed.
-#
-# 2. Copy and modify the provider_config_example.yaml depending on the provider you are testing.
-#
-# 3. Run:
-#
-# ```bash
-# PROVIDER_ID=<your_provider> \
-# PROVIDER_CONFIG=provider_config.yaml \
-# pytest -s llama_stack/providers/tests/eval/test_eval.py \
-# --tb=short --disable-warnings
-# ```
-
-
-@pytest_asyncio.fixture(scope="session")
-async def eval_settings():
-    impls = await resolve_impls_for_test(
-        Api.eval, deps=[Api.datasetio, Api.scoring, Api.inference]
-    )
-    return {
-        "eval_impl": impls[Api.eval],
-        "scoring_impl": impls[Api.scoring],
-        "datasets_impl": impls[Api.datasets],
-    }
-
-
-@pytest.mark.asyncio
-async def test_eval(eval_settings):
-    datasets_impl = eval_settings["datasets_impl"]
-    await register_dataset(
-        datasets_impl,
-        for_generation=True,
-        dataset_id="test_dataset_for_eval",
-    )
-
-    response = await datasets_impl.list_datasets()
-    assert len(response) == 1
-
-    eval_impl = eval_settings["eval_impl"]
-    response = await eval_impl.evaluate_batch(
-        dataset_id=response[0].identifier,
-        candidate=ModelCandidate(
-            model="Llama3.2-1B-Instruct",
-            sampling_params=SamplingParams(),
-        ),
-        scoring_functions=[
-            "meta-reference::subset_of",
-            "meta-reference::llm_as_judge_8b_correctness",
-        ],
-    )
-    assert response.job_id == "0"
-    job_status = await eval_impl.job_status(response.job_id)
-
-    assert job_status and job_status.value == "completed"
-
-    eval_response = await eval_impl.job_result(response.job_id)
-
-    assert eval_response is not None
-    assert len(eval_response.generations) == 5
-    assert "meta-reference::subset_of" in eval_response.scores
-    assert "meta-reference::llm_as_judge_8b_correctness" in eval_response.scores
+# pytest llama_stack/providers/tests/eval/test_eval.py
+#   -m "meta_reference"
+#   -v -s --tb=short --disable-warnings
+
+
+class Testeval:
+    @pytest.mark.asyncio
+    async def test_eval_tasks_list(self, eval_stack):
+        # NOTE: this needs you to ensure that you are starting from a clean state
+        # but so far we don't have an unregister API unfortunately, so be careful
+        eval_tasks_impl = eval_stack[Api.eval_tasks]
+        response = await eval_tasks_impl.list_eval_tasks()
+        assert isinstance(response, list)
+        assert len(response) == 0
+
+    @pytest.mark.asyncio
+    async def test_eval_evaluate_rows(self, eval_stack):
+        eval_impl, eval_tasks_impl, datasetio_impl, datasets_impl = (
+            eval_stack[Api.eval],
+            eval_stack[Api.eval_tasks],
+            eval_stack[Api.datasetio],
+            eval_stack[Api.datasets],
+        )
+        await register_dataset(
+            datasets_impl, for_generation=True, dataset_id="test_dataset_for_eval"
+        )
+        response = await datasets_impl.list_datasets()
+        assert len(response) == 1
+        rows = await datasetio_impl.get_rows_paginated(
+            dataset_id="test_dataset_for_eval",
+            rows_in_page=3,
+        )
+        assert len(rows.rows) == 3
+
+        scoring_functions = [
+            "meta-reference::llm_as_judge_8b_correctness",
+            "meta-reference::equality",
+        ]
+        task_id = "meta-reference::app_eval"
+        task_def = EvalTaskDefWithProvider(
+            identifier=task_id,
+            dataset_id="test_dataset_for_eval",
+            scoring_functions=scoring_functions,
+            provider_id="meta-reference",
+        )
+        await eval_tasks_impl.register_eval_task(task_def)
+
+        response = await eval_impl.evaluate_rows(
+            task_id=task_id,
+            input_rows=rows.rows,
+            scoring_functions=scoring_functions,
+            task_config=AppEvalTaskConfig(
+                eval_candidate=ModelCandidate(
+                    model="Llama3.2-3B-Instruct",
+                    sampling_params=SamplingParams(),
+                ),
+            ),
+        )
+        assert len(response.generations) == 3
+        assert "meta-reference::llm_as_judge_8b_correctness" in response.scores
+        assert "meta-reference::equality" in response.scores
+
+    @pytest.mark.asyncio
+    async def test_eval_run_eval(self, eval_stack):
+        eval_impl, eval_tasks_impl, datasets_impl = (
+            eval_stack[Api.eval],
+            eval_stack[Api.eval_tasks],
+            eval_stack[Api.datasets],
+        )
+        await register_dataset(
+            datasets_impl, for_generation=True, dataset_id="test_dataset_for_eval"
+        )
+
+        scoring_functions = [
+            "meta-reference::llm_as_judge_8b_correctness",
+            "meta-reference::subset_of",
+        ]
+
+        task_id = "meta-reference::app_eval-2"
+        task_def = EvalTaskDefWithProvider(
+            identifier=task_id,
+            dataset_id="test_dataset_for_eval",
+            scoring_functions=scoring_functions,
+            provider_id="meta-reference",
+        )
+        await eval_tasks_impl.register_eval_task(task_def)
+        response = await eval_impl.run_eval(
+            task_id=task_id,
+            task_config=AppEvalTaskConfig(
+                eval_candidate=ModelCandidate(
+                    model="Llama3.2-3B-Instruct",
+                    sampling_params=SamplingParams(),
+                ),
+            ),
+        )
+        assert response.job_id == "0"
+        job_status = await eval_impl.job_status(task_id, response.job_id)
+        assert job_status and job_status.value == "completed"
+        eval_response = await eval_impl.job_result(task_id, response.job_id)
+        assert eval_response is not None
+        assert len(eval_response.generations) == 5
+        assert "meta-reference::subset_of" in eval_response.scores
+        assert "meta-reference::llm_as_judge_8b_correctness" in eval_response.scores

View file

@@ -10,7 +10,7 @@ import pytest
 import pytest_asyncio

 from llama_stack.distribution.datatypes import Api, Provider
-from llama_stack.providers.inline.meta_reference.inference import (
+from llama_stack.providers.inline.inference.meta_reference import (
     MetaReferenceInferenceConfig,
 )

@@ -64,6 +64,7 @@ def inference_ollama(inference_model) -> ProviderFixture:
     inference_model = (
         [inference_model] if isinstance(inference_model, str) else inference_model
     )
+    print("!!!", inference_model)
     if "Llama3.1-8B-Instruct" in inference_model:
         pytest.skip("Ollama only supports Llama3.2-3B-Instruct for testing")

View file

@@ -11,7 +11,7 @@ import pytest
 import pytest_asyncio

 from llama_stack.distribution.datatypes import Api, Provider
-from llama_stack.providers.inline.meta_reference.memory import FaissImplConfig
+from llama_stack.providers.inline.memory.faiss import FaissImplConfig
 from llama_stack.providers.remote.memory.pgvector import PGVectorConfig
 from llama_stack.providers.remote.memory.weaviate import WeaviateConfig

View file

@@ -8,7 +8,7 @@ import pytest
 import pytest_asyncio

 from llama_stack.distribution.datatypes import Api, Provider
-from llama_stack.providers.inline.meta_reference.safety import (
+from llama_stack.providers.inline.safety.meta_reference import (
     LlamaGuardShieldConfig,
     SafetyConfig,
 )

View file

@ -0,0 +1,68 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest
from ..conftest import get_provider_fixture_overrides
from ..datasetio.fixtures import DATASETIO_FIXTURES
from ..inference.fixtures import INFERENCE_FIXTURES
from .fixtures import SCORING_FIXTURES
DEFAULT_PROVIDER_COMBINATIONS = [
pytest.param(
{
"scoring": "meta_reference",
"datasetio": "meta_reference",
"inference": "fireworks",
},
id="meta_reference_scoring_fireworks_inference",
marks=pytest.mark.meta_reference_scoring_fireworks_inference,
),
pytest.param(
{
"scoring": "meta_reference",
"datasetio": "meta_reference",
"inference": "together",
},
id="meta_reference_scoring_together_inference",
marks=pytest.mark.meta_reference_scoring_together_inference,
),
]
def pytest_configure(config):
for fixture_name in [
"meta_reference_scoring_fireworks_inference",
"meta_reference_scoring_together_inference",
]:
config.addinivalue_line(
"markers",
f"{fixture_name}: marks tests as {fixture_name} specific",
)
def pytest_addoption(parser):
parser.addoption(
"--inference-model",
action="store",
default="Llama3.2-3B-Instruct",
help="Specify the inference model to use for testing",
)
def pytest_generate_tests(metafunc):
if "scoring_stack" in metafunc.fixturenames:
available_fixtures = {
"scoring": SCORING_FIXTURES,
"datasetio": DATASETIO_FIXTURES,
"inference": INFERENCE_FIXTURES,
}
combinations = (
get_provider_fixture_overrides(metafunc.config, available_fixtures)
or DEFAULT_PROVIDER_COMBINATIONS
)
metafunc.parametrize("scoring_stack", combinations, indirect=True)

View file

@ -0,0 +1,60 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest
import pytest_asyncio
from llama_stack.distribution.datatypes import Api, Provider
from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
from ..conftest import ProviderFixture, remote_stack_fixture
@pytest.fixture(scope="session")
def scoring_remote() -> ProviderFixture:
return remote_stack_fixture()
@pytest.fixture(scope="session")
def scoring_meta_reference() -> ProviderFixture:
return ProviderFixture(
providers=[
Provider(
provider_id="meta-reference",
provider_type="meta-reference",
config={},
)
],
)
SCORING_FIXTURES = ["meta_reference", "remote"]
@pytest_asyncio.fixture(scope="session")
async def scoring_stack(request):
fixture_dict = request.param
providers = {}
provider_data = {}
for key in ["datasetio", "scoring", "inference"]:
fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}")
providers[key] = fixture.providers
if fixture.provider_data:
provider_data.update(fixture.provider_data)
impls = await resolve_impls_for_test_v2(
[Api.scoring, Api.datasetio, Api.inference],
providers,
provider_data,
)
return (
impls[Api.scoring],
impls[Api.scoring_functions],
impls[Api.datasetio],
impls[Api.datasets],
)

View file

@ -1,17 +0,0 @@
providers:
datasetio:
- provider_id: test-meta
provider_type: meta-reference
config: {}
scoring:
- provider_id: test-meta
provider_type: meta-reference
config: {}
- provider_id: test-braintrust
provider_type: braintrust
config: {}
inference:
- provider_id: tgi0
provider_type: remote::tgi
config:
url: http://127.0.0.1:5009

View file

@@ -3,150 +3,109 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-import pytest
-import pytest_asyncio
-
-from llama_stack.apis.common.type_system import *  # noqa: F403
-from llama_stack.apis.datasetio import *  # noqa: F403
-from llama_stack.distribution.datatypes import *  # noqa: F403
+
+import pytest
+
 from llama_stack.apis.scoring_functions import *  # noqa: F403
 from llama_stack.providers.tests.datasetio.test_datasetio import register_dataset
-from llama_stack.providers.tests.resolver import resolve_impls_for_test

 # How to run this test:
 #
-# 1. Ensure you have a conda with the right dependencies installed. This is a bit tricky
-#    since it depends on the provider you are testing. On top of that you need
-#    `pytest` and `pytest-asyncio` installed.
-#
-# 2. Copy and modify the provider_config_example.yaml depending on the provider you are testing.
-#
-# 3. Run:
-#
-# ```bash
-# PROVIDER_ID=<your_provider> \
-# PROVIDER_CONFIG=provider_config.yaml \
-# pytest -s llama_stack/providers/tests/scoring/test_scoring.py \
-# --tb=short --disable-warnings
-# ```
-
-
-@pytest_asyncio.fixture(scope="session")
-async def scoring_settings():
-    impls = await resolve_impls_for_test(
-        Api.scoring, deps=[Api.datasetio, Api.inference]
-    )
-    return {
-        "scoring_impl": impls[Api.scoring],
-        "scoring_functions_impl": impls[Api.scoring_functions],
-        "datasets_impl": impls[Api.datasets],
-    }
-
-
-@pytest_asyncio.fixture(scope="session")
-async def provider_scoring_functions():
-    return {
-        "meta-reference": {
-            "meta-reference::equality",
-            "meta-reference::subset_of",
-            "meta-reference::llm_as_judge_8b_correctness",
-        },
-        "braintrust": {
-            "braintrust::factuality",
-            "braintrust::answer-correctness",
-        },
-    }
-
-
-@pytest.mark.asyncio
-async def test_scoring_functions_list(scoring_settings, provider_scoring_functions):
-    scoring_impl = scoring_settings["scoring_impl"]
-    scoring_functions_impl = scoring_settings["scoring_functions_impl"]
-    scoring_functions = await scoring_functions_impl.list_scoring_functions()
-    assert isinstance(scoring_functions, list)
-    assert len(scoring_functions) > 0
-    function_ids = [f.identifier for f in scoring_functions]
-    # get current provider_type we're testing
-    provider = scoring_impl.routing_table.get_provider_impl(function_ids[0])
-    provider_type = provider.__provider_spec__.provider_type
-    for x in provider_scoring_functions[provider_type]:
-        assert x in function_ids
-
-
-@pytest.mark.asyncio
-async def test_scoring_functions_register(scoring_settings):
-    scoring_impl = scoring_settings["scoring_impl"]
-    scoring_functions_impl = scoring_settings["scoring_functions_impl"]
-    datasets_impl = scoring_settings["datasets_impl"]
-    # get current provider_type we're testing
-    scoring_functions = await scoring_functions_impl.list_scoring_functions()
-    function_ids = [f.identifier for f in scoring_functions]
-    provider = scoring_impl.routing_table.get_provider_impl(function_ids[0])
-    provider_type = provider.__provider_spec__.provider_type
-    if provider_type not in ("meta-reference"):
-        pytest.skip(
-            "Other scoring providers don't support registering scoring functions."
-        )
-
-    await register_dataset(datasets_impl)
-    response = await datasets_impl.list_datasets()
-    assert len(response) == 1
-
-    test_prompt = """Output a number between 0 to 10. Your answer must match the format \n Number: <answer>"""
-    # register the scoring function
-    await scoring_functions_impl.register_scoring_function(
-        ScoringFnDefWithProvider(
-            identifier="meta-reference::llm_as_judge_8b_random",
-            description="Llm As Judge Scoring Function",
-            parameters=[],
-            return_type=NumberType(),
-            context=LLMAsJudgeContext(
-                prompt_template=test_prompt,
-                judge_model="Llama3.1-8B-Instruct",
-                judge_score_regex=[r"Number: (\d+)"],
-            ),
-            provider_id="test-meta",
-        )
-    )
-
-    scoring_functions = await scoring_functions_impl.list_scoring_functions()
-    assert isinstance(scoring_functions, list)
-    assert len(scoring_functions) > 0
-    function_ids = [f.identifier for f in scoring_functions]
-    assert "meta-reference::llm_as_judge_8b_random" in function_ids
-
-    # test score using newly registered scoring function
-    await register_dataset(datasets_impl)
-    response = await datasets_impl.list_datasets()
-    assert len(response) == 1
-    response = await scoring_impl.score_batch(
-        dataset_id=response[0].identifier,
-        scoring_functions=[
-            "meta-reference::llm_as_judge_8b_random",
-        ],
-    )
-    assert "meta-reference::llm_as_judge_8b_random" in response.results
-
-
-@pytest.mark.asyncio
-async def test_scoring_score(scoring_settings, provider_scoring_functions):
-    scoring_impl = scoring_settings["scoring_impl"]
-    datasets_impl = scoring_settings["datasets_impl"]
-    scoring_functions_impl = scoring_settings["scoring_functions_impl"]
-    await register_dataset(datasets_impl)
-    response = await datasets_impl.list_datasets()
-    assert len(response) == 1
-
-    # get current provider_type we're testing
-    scoring_functions = await scoring_functions_impl.list_scoring_functions()
-    function_ids = [f.identifier for f in scoring_functions]
-    provider = scoring_impl.routing_table.get_provider_impl(function_ids[0])
-    provider_type = provider.__provider_spec__.provider_type
-
-    response = await scoring_impl.score_batch(
-        dataset_id=response[0].identifier,
-        scoring_functions=list(provider_scoring_functions[provider_type]),
-    )
-
-    assert len(response.results) == len(provider_scoring_functions[provider_type])
-    for x in provider_scoring_functions[provider_type]:
-        assert x in response.results
+# pytest llama_stack/providers/tests/scoring/test_scoring.py
+#   -m "meta_reference"
+#   -v -s --tb=short --disable-warnings
+
+
+class TestScoring:
+    @pytest.mark.asyncio
+    async def test_scoring_functions_list(self, scoring_stack):
+        # NOTE: this needs you to ensure that you are starting from a clean state
+        # but so far we don't have an unregister API unfortunately, so be careful
+        _, scoring_functions_impl, _, _ = scoring_stack
+        response = await scoring_functions_impl.list_scoring_functions()
+        assert isinstance(response, list)
+        assert len(response) > 0
+
+    @pytest.mark.asyncio
+    async def test_scoring_score(self, scoring_stack):
+        scoring_impl, scoring_functions_impl, datasetio_impl, datasets_impl = (
+            scoring_stack
+        )
+        await register_dataset(datasets_impl)
+        response = await datasets_impl.list_datasets()
+        assert len(response) == 1
+
+        # scoring individual rows
+        rows = await datasetio_impl.get_rows_paginated(
+            dataset_id="test_dataset",
+            rows_in_page=3,
+        )
+        assert len(rows.rows) == 3
+
+        scoring_functions = {
+            "meta-reference::llm_as_judge_8b_correctness": None,
+            "meta-reference::equality": None,
+        }
+        response = await scoring_impl.score(
+            input_rows=rows.rows,
+            scoring_functions=scoring_functions,
+        )
+        assert len(response.results) == len(scoring_functions)
+        for x in scoring_functions:
+            assert x in response.results
+            assert len(response.results[x].score_rows) == len(rows.rows)
+
+        # score batch
+        response = await scoring_impl.score_batch(
+            dataset_id="test_dataset",
+            scoring_functions=scoring_functions,
+        )
+        assert len(response.results) == len(scoring_functions)
+        for x in scoring_functions:
+            assert x in response.results
+            assert len(response.results[x].score_rows) == 5
+
+    @pytest.mark.asyncio
+    async def test_scoring_score_with_params(self, scoring_stack):
+        scoring_impl, scoring_functions_impl, datasetio_impl, datasets_impl = (
+            scoring_stack
+        )
+        await register_dataset(datasets_impl)
+        response = await datasets_impl.list_datasets()
+        assert len(response) == 1
+
+        # scoring individual rows
+        rows = await datasetio_impl.get_rows_paginated(
+            dataset_id="test_dataset",
+            rows_in_page=3,
+        )
+        assert len(rows.rows) == 3
+
+        scoring_functions = {
+            "meta-reference::llm_as_judge_8b_correctness": LLMAsJudgeScoringFnParams(
+                judge_model="Llama3.1-405B-Instruct",
+                prompt_template="Output a number response in the following format: Score: <number>, where <number> is the number between 0 and 9.",
+                judge_score_regexes=[r"Score: (\d+)"],
+            )
+        }
+
+        response = await scoring_impl.score(
+            input_rows=rows.rows,
+            scoring_functions=scoring_functions,
+        )
+        assert len(response.results) == len(scoring_functions)
+        for x in scoring_functions:
+            assert x in response.results
+            assert len(response.results[x].score_rows) == len(rows.rows)
+
+        # score batch
+        response = await scoring_impl.score_batch(
+            dataset_id="test_dataset",
+            scoring_functions=scoring_functions,
+        )
+        assert len(response.results) == len(scoring_functions)
+        for x in scoring_functions:
+            assert x in response.results
+            assert len(response.results[x].score_rows) == 5