Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-02 08:44:44 +00:00)

Commit cc6edf6287: Merge branch 'eval_task_register' into mmlu_benchmark

72 changed files with 306 additions and 304 deletions
.github/PULL_REQUEST_TEMPLATE.md (vendored, 29 lines changed)

@@ -1,17 +1,14 @@
 # What does this PR do?

-Closes # (issue)
+In short, provide a summary of what this PR does and why. Usually, the relevant context should be present in a linked issue.

+- [ ] Addresses issue (#issue)

 ## Feature/Issue validation/testing/test plan

-Please describe the tests that you ran to verify your changes and relevant result summary. Provide instructions so it can be reproduced.
-Please also list any relevant details for your test configuration or test plan.
+Please describe:
+ - tests you ran to verify your changes with result summaries.
+ - provide instructions so it can be reproduced.

-- [ ] Test A
-Logs for Test A
-
-- [ ] Test B
-Logs for Test B

 ## Sources
@@ -20,12 +17,10 @@ Please link relevant resources if necessary.


 ## Before submitting
-- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
-- [ ] Did you read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md),
-      Pull Request section?
-- [ ] Was this discussed/approved via a Github issue? Please add a link
-      to it if that's the case.
-- [ ] Did you make sure to update the documentation with your changes?
-- [ ] Did you write any new necessary tests?

-Thanks for contributing 🎉!
+- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
+- [ ] Ran pre-commit to handle lint / formatting issues.
+- [ ] Read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md),
+      Pull Request section?
+- [ ] Updated relevant documentation.
+- [ ] Wrote necessary unit or integration tests.
@@ -38,14 +38,16 @@ EvalCandidate = Annotated[

 @json_schema_type
 class BenchmarkEvalTaskConfig(BaseModel):
-    eval_candidate: EvalCandidate  # type: ignore
+    type: Literal["benchmark"] = "benchmark"
+    eval_candidate: EvalCandidate


 @json_schema_type
 class AppEvalTaskConfig(BaseModel):
-    eval_candidate: EvalCandidate  # type: ignore
-    scoring_params: Dict[str, ScoringFnParams] = Field(  # type: ignore
-        description="Map between scoring function id and parameters",
+    type: Literal["app"] = "app"
+    eval_candidate: EvalCandidate
+    scoring_params: Dict[str, ScoringFnParams] = Field(
+        description="Map between scoring function id and parameters for each scoring function you want to run",
         default_factory=dict,
     )
     # we could optinally add any specific dataset config here
@@ -64,18 +66,18 @@ class EvaluateResponse(BaseModel):


 class Eval(Protocol):
-    @webmethod(route="/eval/run_benchmark_eval", method="POST")
-    async def run_benchmark_eval(
+    @webmethod(route="/eval/run_benchmark", method="POST")
+    async def run_benchmark(
         self,
         benchmark_id: str,
-        eval_task_config: BenchmarkEvalTaskConfig,
+        benchmark_config: BenchmarkEvalTaskConfig,
     ) -> Job: ...

     @webmethod(route="/eval/run_eval", method="POST")
     async def run_eval(
         self,
-        eval_task_def: EvalTaskDef,
-        eval_task_config: EvalTaskConfig,
+        task: EvalTaskDef,
+        task_config: AppEvalTaskConfig,
     ) -> Job: ...

     @webmethod(route="/eval/evaluate_rows", method="POST")
@@ -83,14 +85,17 @@ class Eval(Protocol):
         self,
         input_rows: List[Dict[str, Any]],
         scoring_functions: List[str],
-        eval_task_config: EvalTaskConfig,  # type: ignore
+        task_config: EvalTaskConfig,
+        eval_task_id: Optional[str] = None,
     ) -> EvaluateResponse: ...

     @webmethod(route="/eval/job/status", method="GET")
-    async def job_status(self, job_id: str) -> Optional[JobStatus]: ...
+    async def job_status(
+        self, job_id: str, eval_task_id: str
+    ) -> Optional[JobStatus]: ...

     @webmethod(route="/eval/job/cancel", method="POST")
-    async def job_cancel(self, job_id: str) -> None: ...
+    async def job_cancel(self, job_id: str, eval_task_id: str) -> None: ...

     @webmethod(route="/eval/job/result", method="GET")
-    async def job_result(self, job_id: str) -> EvaluateResponse: ...
+    async def job_result(self, job_id: str, eval_task_id: str) -> EvaluateResponse: ...
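For orientation, a minimal, hedged sketch of how a caller might exercise the revised Eval protocol above. The eval task, candidate, and identifiers are hypothetical placeholders, and the constructors are assumed rather than taken from this diff.

```python
from llama_stack.apis.eval.eval import AppEvalTaskConfig, JobStatus


async def run_app_eval(eval_impl, my_eval_task, candidate):
    # Sketch only: `eval_impl` implements the Eval protocol above, `my_eval_task`
    # is a pre-built EvalTaskDef, and `candidate` is an EvalCandidate (their
    # construction is outside this diff).
    task_config = AppEvalTaskConfig(
        eval_candidate=candidate,   # `type` defaults to "app"
        scoring_params={},          # optional per-scoring-function overrides
    )

    job = await eval_impl.run_eval(task=my_eval_task, task_config=task_config)

    # Job-management calls now take the eval task id alongside the job id
    # (assuming the returned Job exposes a `job_id` field).
    status = await eval_impl.job_status(job.job_id, my_eval_task.identifier)
    if status == JobStatus.completed:
        return await eval_impl.job_result(job.job_id, my_eval_task.identifier)
    return None
```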
@@ -48,8 +48,7 @@ class Scoring(Protocol):
     async def score_batch(
         self,
         dataset_id: str,
-        scoring_functions: List[str],
-        scoring_params: Optional[Dict[str, ScoringFnParams]] = None,
+        scoring_functions: Optional[Dict[str, ScoringFnParams]] = None,
         save_results_dataset: bool = False,
     ) -> ScoreBatchResponse: ...

@@ -57,6 +56,5 @@ class Scoring(Protocol):
     async def score(
         self,
         input_rows: List[Dict[str, Any]],
-        scoring_functions: List[str],
-        scoring_params: Optional[Dict[str, ScoringFnParams]] = None,
+        scoring_functions: Optional[Dict[str, ScoringFnParams]] = None,
     ) -> ScoreResponse: ...
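A similar hedged sketch of a call against the reworked Scoring signature, where the single dict now carries both the scoring-function ids and their optional parameter overrides. The scoring-function ids and judge model below are illustrative only, and imports of the params classes (defined in the next hunks) are omitted.

```python
async def score_rows(scoring_impl, rows):
    # Sketch only: `scoring_impl` implements the Scoring protocol above.
    # Keys are scoring-function ids; values are optional ScoringFnParams
    # overrides, with None meaning "use the registered/default params".
    response = await scoring_impl.score(
        input_rows=rows,
        scoring_functions={
            "basic::equality": None,  # illustrative id
            "llm-as-judge::base": LLMAsJudgeScoringFnParams(  # illustrative id
                judge_model="Llama3.1-8B-Instruct",
            ),
        },
    )
    for fn_id, result in response.results.items():
        print(fn_id, result)
    return response
```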
@@ -33,7 +33,7 @@ class ScoringConfigType(Enum):

 @json_schema_type
 class LLMAsJudgeScoringFnParams(BaseModel):
-    type: Literal[ScoringConfigType.llm_as_judge.value] = (  # type: ignore
+    type: Literal[ScoringConfigType.llm_as_judge.value] = (
         ScoringConfigType.llm_as_judge.value
     )
     judge_model: str
@@ -46,7 +46,7 @@ class LLMAsJudgeScoringFnParams(BaseModel):

 @json_schema_type
 class RegexParserScoringFnParams(BaseModel):
-    type: Literal[ScoringConfigType.regex_parser.value] = (  # type: ignore
+    type: Literal[ScoringConfigType.regex_parser.value] = (
         ScoringConfigType.regex_parser.value
     )
     parsing_regexes: Optional[List[str]] = Field(
@@ -75,8 +75,8 @@ class ScoringFnDef(BaseModel):
     return_type: ParamType = Field(
         description="The return type of the deterministic function",
     )
-    params: Optional[ScoringFnParams] = Field(  # type: ignore
-        description="The parameters for the scoring function for benchmark eval, we could override this for app eval",
+    params: Optional[ScoringFnParams] = Field(
+        description="The parameters for the scoring function for benchmark eval, these can be overridden for app eval",
         default=None,
     )
     # We can optionally add information here to support packaging of code, etc.
@@ -8,6 +8,8 @@ import inspect

 from typing import Any, Dict, List, Set

+from termcolor import cprint
+
 from llama_stack.providers.datatypes import *  # noqa: F403
 from llama_stack.distribution.datatypes import *  # noqa: F403

@@ -100,6 +102,12 @@ async def resolve_impls(
         )

         p = provider_registry[api][provider.provider_type]
+        if p.deprecation_warning:
+            cprint(
+                f"Provider `{provider.provider_type}` for API `{api}` is deprecated and will be removed in a future release: {p.deprecation_warning}",
+                "red",
+                attrs=["bold"],
+            )
         p.deps__ = [a.value for a in p.api_dependencies]
         spec = ProviderWithSpec(
             spec=p,
@@ -216,18 +216,16 @@ class ScoringRouter(Scoring):
     async def score_batch(
         self,
         dataset_id: str,
-        scoring_functions: List[str],
-        scoring_params: Optional[Dict[str, ScoringFnParams]] = None,
+        scoring_functions: Optional[Dict[str, ScoringFnParams]] = None,
         save_results_dataset: bool = False,
     ) -> ScoreBatchResponse:
         res = {}
-        for fn_identifier in scoring_functions:
+        for fn_identifier in scoring_functions.keys():
             score_response = await self.routing_table.get_provider_impl(
                 fn_identifier
             ).score_batch(
                 dataset_id=dataset_id,
-                scoring_functions=[fn_identifier],
-                scoring_params=scoring_params,
+                scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
             )
             res.update(score_response.results)

@@ -241,18 +239,16 @@ class ScoringRouter(Scoring):
     async def score(
         self,
         input_rows: List[Dict[str, Any]],
-        scoring_functions: List[str],
-        scoring_params: Optional[Dict[str, ScoringFnParams]] = None,
+        scoring_functions: Optional[Dict[str, ScoringFnParams]] = None,
     ) -> ScoreResponse:
         res = {}
         # look up and map each scoring function to its provider impl
-        for fn_identifier in scoring_functions:
+        for fn_identifier in scoring_functions.keys():
             score_response = await self.routing_table.get_provider_impl(
                 fn_identifier
             ).score(
                 input_rows=input_rows,
-                scoring_functions=[fn_identifier],
-                scoring_params=scoring_params,
+                scoring_functions={fn_identifier: scoring_functions[fn_identifier]},
             )
             res.update(score_response.results)

@@ -272,24 +268,21 @@ class EvalRouter(Eval):
     async def shutdown(self) -> None:
         pass

-    async def run_benchmark_eval(
+    async def run_benchmark(
         self,
         benchmark_id: str,
-        eval_task_config: BenchmarkEvalTaskConfig,
+        benchmark_config: BenchmarkEvalTaskConfig,
     ) -> Job:
         pass

     async def run_eval(
         self,
-        eval_task_def: EvalTaskDef,
-        eval_task_config: EvalTaskConfig,
+        task: EvalTaskDef,
+        task_config: AppEvalTaskConfig,
     ) -> Job:
-        # NOTE: We need to use DEFAULT_EVAL_TASK_IDENTIFIER to make the router work for all app evals
-        return await self.routing_table.get_provider_impl(
-            DEFAULT_EVAL_TASK_IDENTIFIER
-        ).run_eval(
-            eval_task_def=eval_task_def,
-            eval_task_config=eval_task_config,
+        return await self.routing_table.get_provider_impl(task.identifier).run_eval(
+            task=task,
+            task_config=task_config,
         )

     @webmethod(route="/eval/evaluate_rows", method="POST")
@@ -297,29 +290,42 @@ class EvalRouter(Eval):
         self,
         input_rows: List[Dict[str, Any]],
         scoring_functions: List[str],
-        eval_task_config: EvalTaskConfig,  # type: ignore
+        task_config: EvalTaskConfig,
+        eval_task_id: Optional[str] = None,
     ) -> EvaluateResponse:
         # NOTE: This is to deal with the case where we do not pre-register an eval benchmark_task
         # We use default DEFAULT_EVAL_TASK_IDENTIFIER as identifier
-        return await self.routing_table.get_provider_impl(
-            DEFAULT_EVAL_TASK_IDENTIFIER
-        ).evaluate_rows(
+        if eval_task_id is None:
+            eval_task_id = DEFAULT_EVAL_TASK_IDENTIFIER
+        return await self.routing_table.get_provider_impl(eval_task_id).evaluate_rows(
             input_rows=input_rows,
             scoring_functions=scoring_functions,
-            eval_task_config=eval_task_config,
+            task_config=task_config,
         )

-    async def job_status(self, job_id: str) -> Optional[JobStatus]:
-        return await self.routing_table.get_provider_impl(
-            DEFAULT_EVAL_TASK_IDENTIFIER
-        ).job_status(job_id)
+    async def job_status(
+        self,
+        job_id: str,
+        eval_task_id: str,
+    ) -> Optional[JobStatus]:
+        return await self.routing_table.get_provider_impl(eval_task_id).job_status(
+            job_id, eval_task_id
+        )

-    async def job_cancel(self, job_id: str) -> None:
-        await self.routing_table.get_provider_impl(
-            DEFAULT_EVAL_TASK_IDENTIFIER
-        ).job_cancel(job_id)
+    async def job_cancel(
+        self,
+        job_id: str,
+        eval_task_id: str,
+    ) -> None:
+        await self.routing_table.get_provider_impl(eval_task_id).job_cancel(
+            job_id, eval_task_id
+        )

-    async def job_result(self, job_id: str) -> EvaluateResponse:
-        return await self.routing_table.get_provider_impl(
-            DEFAULT_EVAL_TASK_IDENTIFIER
-        ).job_result(job_id)
+    async def job_result(
+        self,
+        job_id: str,
+        eval_task_id: str,
+    ) -> EvaluateResponse:
+        return await self.routing_table.get_provider_impl(eval_task_id).job_result(
+            job_id, eval_task_id
+        )
@@ -31,7 +31,7 @@ from llama_stack.distribution.distribution import (
     get_provider_registry,
 )

-from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
+from llama_stack.distribution.store.registry import create_dist_registry

 from llama_stack.providers.utils.telemetry.tracing import (
     end_trace,
@@ -42,8 +42,6 @@ from llama_stack.providers.utils.telemetry.tracing import (
 from llama_stack.distribution.datatypes import *  # noqa: F403
 from llama_stack.distribution.request_headers import set_request_provider_data
 from llama_stack.distribution.resolver import resolve_impls
-from llama_stack.distribution.store import CachedDiskDistributionRegistry
-from llama_stack.providers.utils.kvstore import kvstore_impl, SqliteKVStoreConfig

 from .endpoints import get_all_api_endpoints

@@ -281,21 +279,8 @@ def main(
         config = StackRunConfig(**yaml.safe_load(fp))

     app = FastAPI()
-    # instantiate kvstore for storing and retrieving distribution metadata
-    if config.metadata_store:
-        dist_kvstore = asyncio.run(kvstore_impl(config.metadata_store))
-    else:
-        dist_kvstore = asyncio.run(
-            kvstore_impl(
-                SqliteKVStoreConfig(
-                    db_path=(
-                        DISTRIBS_BASE_DIR / config.image_name / "kvstore.db"
-                    ).as_posix()
-                )
-            )
-        )

-    dist_registry = CachedDiskDistributionRegistry(dist_kvstore)
+    dist_registry, dist_kvstore = asyncio.run(create_dist_registry(config))

     impls = asyncio.run(resolve_impls(config, get_provider_registry(), dist_registry))
     if Api.telemetry in impls:
@@ -9,9 +9,17 @@ from typing import Dict, List, Protocol

 import pydantic

-from llama_stack.distribution.datatypes import RoutableObjectWithProvider
+from llama_stack.distribution.datatypes import (
+    RoutableObjectWithProvider,
+    StackRunConfig,
+)
+from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR

-from llama_stack.providers.utils.kvstore import KVStore
+from llama_stack.providers.utils.kvstore import (
+    KVStore,
+    kvstore_impl,
+    SqliteKVStoreConfig,
+)


 class DistributionRegistry(Protocol):
@@ -133,3 +141,21 @@ class CachedDiskDistributionRegistry(DiskDistributionRegistry):
             self.cache[obj.identifier].append(obj)

         return success
+
+
+async def create_dist_registry(
+    config: StackRunConfig,
+) -> tuple[CachedDiskDistributionRegistry, KVStore]:
+    # instantiate kvstore for storing and retrieving distribution metadata
+    if config.metadata_store:
+        dist_kvstore = await kvstore_impl(config.metadata_store)
+    else:
+        dist_kvstore = await kvstore_impl(
+            SqliteKVStoreConfig(
+                db_path=(
+                    DISTRIBS_BASE_DIR / config.image_name / "kvstore.db"
+                ).as_posix()
+            )
+        )
+
+    return CachedDiskDistributionRegistry(dist_kvstore), dist_kvstore
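For context, a brief sketch of how the new helper is meant to be consumed at server startup, mirroring the server.py change elsewhere in this commit; the run.yaml path is a placeholder.

```python
import asyncio

import yaml

from llama_stack.distribution.datatypes import StackRunConfig
from llama_stack.distribution.store.registry import create_dist_registry

# Sketch: build the distribution registry from a run config. When
# `metadata_store` is unset, create_dist_registry falls back to a SQLite
# kvstore at DISTRIBS_BASE_DIR / <image_name> / "kvstore.db".
with open("run.yaml") as fp:  # placeholder path
    config = StackRunConfig(**yaml.safe_load(fp))

dist_registry, dist_kvstore = asyncio.run(create_dist_registry(config))
```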
@@ -90,6 +90,10 @@ class ProviderSpec(BaseModel):
         default_factory=list,
         description="Higher-level API surfaces may depend on other providers to provide their functionality",
     )
+    deprecation_warning: Optional[str] = Field(
+        default=None,
+        description="If this provider is deprecated, specify the warning message here",
+    )

     # used internally by the resolver; this is a hack for now
     deps__: List[str] = Field(default_factory=list)
@@ -4,10 +4,9 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from pydantic import BaseModel, Field
-
 from llama_stack.providers.utils.kvstore import KVStoreConfig
 from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
+from pydantic import BaseModel, Field


 class MetaReferenceAgentsImplConfig(BaseModel):
@@ -11,9 +11,8 @@ from datetime import datetime

 from typing import List, Optional
 from llama_stack.apis.agents import *  # noqa: F403
-from pydantic import BaseModel
-
 from llama_stack.providers.utils.kvstore import KVStore
+from pydantic import BaseModel


 class AgentSessionInfo(BaseModel):
@@ -10,14 +10,13 @@ from jinja2 import Template
 from llama_models.llama3.api import *  # noqa: F403


-from termcolor import cprint  # noqa: F401
-
 from llama_stack.apis.agents import (
     DefaultMemoryQueryGeneratorConfig,
     LLMMemoryQueryGeneratorConfig,
     MemoryQueryGenerator,
     MemoryQueryGeneratorConfig,
 )
+from termcolor import cprint  # noqa: F401
 from llama_stack.apis.inference import *  # noqa: F403

@@ -9,8 +9,7 @@ from typing import List
 from llama_stack.apis.inference import Message
 from llama_stack.apis.safety import *  # noqa: F403

-from llama_stack.providers.inline.meta_reference.agents.safety import ShieldRunnerMixin
-
+from ..safety import ShieldRunnerMixin
 from .builtin import BaseTool

@@ -10,9 +10,8 @@ from llama_models.datatypes import *  # noqa: F403
 from llama_models.sku_list import resolve_model

 from llama_stack.apis.inference import *  # noqa: F401, F403
-from pydantic import BaseModel, Field, field_validator
-
 from llama_stack.providers.utils.inference import supported_inference_models
+from pydantic import BaseModel, Field, field_validator


 class MetaReferenceInferenceConfig(BaseModel):
@@ -35,13 +35,12 @@ from termcolor import cprint

 from llama_stack.apis.inference import *  # noqa: F403

-from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerTokenizerData
-
 from llama_stack.distribution.utils.model_utils import model_local_dir
 from llama_stack.providers.utils.inference.prompt_adapter import (
     augment_content_with_response_format_prompt,
     chat_completion_request_to_messages,
 )
+from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerTokenizerData

 from .config import (
     Fp8QuantizationConfig,
@@ -28,13 +28,13 @@ from fairscale.nn.model_parallel.initialize import (
     get_model_parallel_src_rank,
 )

+from llama_stack.apis.inference import ChatCompletionRequest, CompletionRequest
+
 from pydantic import BaseModel, Field

 from torch.distributed.launcher.api import elastic_launch, LaunchConfig
 from typing_extensions import Annotated

-from llama_stack.apis.inference import ChatCompletionRequest, CompletionRequest
-
 from .generation import TokenResult

@@ -20,16 +20,15 @@ from llama_models.datatypes import CheckpointQuantizationFormat
 from llama_models.llama3.api.args import ModelArgs
 from llama_models.llama3.reference_impl.model import Transformer, TransformerBlock
 from llama_models.sku_list import resolve_model

+from llama_stack.apis.inference import QuantizationType
+
 from termcolor import cprint
 from torch import nn, Tensor

 from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear

-from llama_stack.apis.inference import QuantizationType
-
-from llama_stack.providers.inline.meta_reference.inference.config import (
-    MetaReferenceQuantizedInferenceConfig,
-)
+from ..config import MetaReferenceQuantizedInferenceConfig


 def swiglu_wrapper(
@@ -5,9 +5,9 @@
 # the root directory of this source tree.

 from llama_models.schema_utils import json_schema_type
-from pydantic import BaseModel, Field, field_validator
-
 from llama_stack.providers.utils.inference import supported_inference_models
+from pydantic import BaseModel, Field, field_validator


 @json_schema_type
@@ -5,13 +5,13 @@
 # the root directory of this source tree.

 from llama_models.schema_utils import json_schema_type
-from pydantic import BaseModel
-
 from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR
 from llama_stack.providers.utils.kvstore.config import (
     KVStoreConfig,
     SqliteKVStoreConfig,
 )
+from pydantic import BaseModel


 @json_schema_type
@@ -8,10 +8,11 @@ import logging

 from typing import Any, Dict, List, Optional

-import faiss
 import numpy as np
 from numpy.typing import NDArray

+import faiss
+
 from llama_models.llama3.api.datatypes import *  # noqa: F403

 from llama_stack.apis.memory import *  # noqa: F403
@@ -7,11 +7,17 @@ from enum import Enum
 from llama_models.llama3.api.datatypes import *  # noqa: F403

 from .....apis.common.job_types import Job
-from .....apis.eval.eval import BenchmarkEvalTaskConfig
+from .....apis.eval.eval import (
+    AppEvalTaskConfig,
+    BenchmarkEvalTaskConfig,
+    Eval,
+    EvalTaskConfig,
+    EvaluateResponse,
+    JobStatus,
+)
 from llama_stack.apis.common.type_system import *  # noqa: F403
 from llama_stack.apis.datasetio import DatasetIO
 from llama_stack.apis.datasets import Datasets
-from llama_stack.apis.eval import Eval, EvalTaskConfig, EvaluateResponse, JobStatus
 from llama_stack.apis.eval_tasks import EvalTaskDef
 from llama_stack.apis.inference import Inference
 from llama_stack.apis.scoring import Scoring
@@ -19,6 +25,10 @@ from llama_stack.providers.datatypes import EvalTasksProtocolPrivate

 from .config import MetaReferenceEvalConfig


+# NOTE: this is the default eval task identifier for app eval
+# it is used to make the router work for all app evals
+# For app eval using other eval providers, the eval task identifier will be different
 DEFAULT_EVAL_TASK_IDENTIFIER = "meta-reference::app_eval"

@@ -88,21 +98,21 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
                 f"Dataset {dataset_id} does not have a correct input schema in {expected_schemas}"
             )

-    async def run_benchmark_eval(
+    async def run_benchmark(
         self,
         benchmark_id: str,
-        eval_task_config: BenchmarkEvalTaskConfig,
+        benchmark_config: BenchmarkEvalTaskConfig,
     ) -> Job:
         raise NotImplementedError("Benchmark eval is not implemented yet")

     async def run_eval(
         self,
-        eval_task_def: EvalTaskDef,
-        eval_task_config: EvalTaskConfig,
+        task: EvalTaskDef,
+        task_config: AppEvalTaskConfig,
     ) -> Job:
-        dataset_id = eval_task_def.dataset_id
-        candidate = eval_task_config.eval_candidate
-        scoring_functions = eval_task_def.scoring_functions
+        dataset_id = task.dataset_id
+        candidate = task_config.eval_candidate
+        scoring_functions = task.scoring_functions

         await self.validate_eval_input_dataset_schema(dataset_id=dataset_id)
         all_rows = await self.datasetio_api.get_rows_paginated(
@@ -112,7 +122,7 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
         res = await self.evaluate_rows(
             input_rows=all_rows.rows,
             scoring_functions=scoring_functions,
-            eval_task_config=eval_task_config,
+            task_config=task_config,
         )

         # TODO: currently needs to wait for generation before returning
@@ -125,9 +135,10 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
         self,
         input_rows: List[Dict[str, Any]],
         scoring_functions: List[str],
-        eval_task_config: EvalTaskConfig,
+        task_config: EvalTaskConfig,
+        eval_task_id: Optional[str] = None,
     ) -> EvaluateResponse:
-        candidate = eval_task_config.eval_candidate
+        candidate = task_config.eval_candidate
         if candidate.type == "agent":
             raise NotImplementedError(
                 "Evaluation with generation has not been implemented for agents"
@@ -179,23 +190,33 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate):
             for input_r, generated_r in zip(input_rows, generations)
         ]

+        if task_config.type == "app" and task_config.scoring_params is not None:
+            scoring_functions_dict = {
+                scoring_fn_id: task_config.scoring_params.get(scoring_fn_id, None)
+                for scoring_fn_id in scoring_functions
+            }
+        else:
+            scoring_functions_dict = {
+                scoring_fn_id: None for scoring_fn_id in scoring_functions
+            }
+
         score_response = await self.scoring_api.score(
-            input_rows=score_input_rows, scoring_functions=scoring_functions
+            input_rows=score_input_rows, scoring_functions=scoring_functions_dict
         )

         return EvaluateResponse(generations=generations, scores=score_response.results)

-    async def job_status(self, job_id: str) -> Optional[JobStatus]:
+    async def job_status(self, job_id: str, eval_task_id: str) -> Optional[JobStatus]:
         if job_id in self.jobs:
             return JobStatus.completed

         return None

-    async def job_cancel(self, job_id: str) -> None:
+    async def job_cancel(self, job_id: str, eval_task_id: str) -> None:
         raise NotImplementedError("Job cancel is not implemented yet")

-    async def job_result(self, job_id: str) -> EvaluateResponse:
-        status = await self.job_status(job_id)
+    async def job_result(self, job_id: str, eval_task_id: str) -> EvaluateResponse:
+        status = await self.job_status(job_id, eval_task_id)
         if not status or status != JobStatus.completed:
             raise ValueError(f"Job is not completed, Status: {status.value}")

@@ -1,73 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# the root directory of this source tree.
-
-import tempfile
-
-import pytest
-from llama_stack.apis.memory import MemoryBankType, VectorMemoryBankDef
-from llama_stack.providers.inline.meta_reference.memory.config import FaissImplConfig
-
-from llama_stack.providers.inline.meta_reference.memory.faiss import FaissMemoryImpl
-from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
-
-
-class TestFaissMemoryImpl:
-    @pytest.fixture
-    def faiss_impl(self):
-        # Create a temporary SQLite database file
-        temp_db = tempfile.NamedTemporaryFile(suffix=".db", delete=False)
-        config = FaissImplConfig(kvstore=SqliteKVStoreConfig(db_path=temp_db.name))
-        return FaissMemoryImpl(config)
-
-    @pytest.mark.asyncio
-    async def test_initialize(self, faiss_impl):
-        # Test empty initialization
-        await faiss_impl.initialize()
-        assert len(faiss_impl.cache) == 0
-
-        # Test initialization with existing banks
-        bank = VectorMemoryBankDef(
-            identifier="test_bank",
-            type=MemoryBankType.vector.value,
-            embedding_model="all-MiniLM-L6-v2",
-            chunk_size_in_tokens=512,
-            overlap_size_in_tokens=64,
-        )
-
-        # Register a bank and reinitialize to test loading
-        await faiss_impl.register_memory_bank(bank)
-
-        # Create new instance to test initialization with existing data
-        new_impl = FaissMemoryImpl(faiss_impl.config)
-        await new_impl.initialize()
-
-        assert len(new_impl.cache) == 1
-        assert "test_bank" in new_impl.cache
-
-    @pytest.mark.asyncio
-    async def test_register_memory_bank(self, faiss_impl):
-        bank = VectorMemoryBankDef(
-            identifier="test_bank",
-            type=MemoryBankType.vector.value,
-            embedding_model="all-MiniLM-L6-v2",
-            chunk_size_in_tokens=512,
-            overlap_size_in_tokens=64,
-        )
-
-        await faiss_impl.initialize()
-        await faiss_impl.register_memory_bank(bank)
-
-        assert "test_bank" in faiss_impl.cache
-        assert faiss_impl.cache["test_bank"].bank == bank
-
-        # Verify persistence
-        new_impl = FaissMemoryImpl(faiss_impl.config)
-        await new_impl.initialize()
-        assert "test_bank" in new_impl.cache
-
-
-if __name__ == "__main__":
-    pytest.main([__file__])
@@ -89,8 +89,7 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
     async def score_batch(
         self,
         dataset_id: str,
-        scoring_functions: List[str],
-        scoring_params: Optional[Dict[str, ScoringFnParams]] = None,
+        scoring_functions: Optional[Dict[str, ScoringFnParams]] = None,
         save_results_dataset: bool = False,
     ) -> ScoreBatchResponse:
         await self.validate_scoring_input_dataset_schema(dataset_id=dataset_id)
@@ -101,7 +100,6 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
         res = await self.score(
             input_rows=all_rows.rows,
             scoring_functions=scoring_functions,
-            scoring_params=scoring_params,
         )
         if save_results_dataset:
             # TODO: persist and register dataset on to server for reading
@@ -115,17 +113,14 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
     async def score(
         self,
         input_rows: List[Dict[str, Any]],
-        scoring_functions: List[str],
-        scoring_params: Optional[Dict[str, ScoringFnParams]] = None,
+        scoring_functions: Optional[Dict[str, ScoringFnParams]] = None,
     ) -> ScoreResponse:
         res = {}
-        for scoring_fn_id in scoring_functions:
+        for scoring_fn_id in scoring_functions.keys():
             if scoring_fn_id not in self.scoring_fn_id_impls:
                 raise ValueError(f"Scoring function {scoring_fn_id} is not supported.")
             scoring_fn = self.scoring_fn_id_impls[scoring_fn_id]
-            scoring_fn_params = None
-            if scoring_params is not None:
-                scoring_fn_params = scoring_params.get(scoring_fn_id, None)
+            scoring_fn_params = scoring_functions.get(scoring_fn_id, None)
             score_results = await scoring_fn.score(
                 input_rows, scoring_fn_id, scoring_fn_params
             )
@@ -22,8 +22,8 @@ def available_providers() -> List[ProviderSpec]:
                 "scikit-learn",
             ]
             + kvstore_dependencies(),
-            module="llama_stack.providers.inline.meta_reference.agents",
-            config_class="llama_stack.providers.inline.meta_reference.agents.MetaReferenceAgentsImplConfig",
+            module="llama_stack.providers.inline.agents.meta_reference",
+            config_class="llama_stack.providers.inline.agents.meta_reference.MetaReferenceAgentsImplConfig",
             api_dependencies=[
                 Api.inference,
                 Api.safety,
@@ -27,8 +27,8 @@ def available_providers() -> List[ProviderSpec]:
             api=Api.inference,
             provider_type="meta-reference",
             pip_packages=META_REFERENCE_DEPS,
-            module="llama_stack.providers.inline.meta_reference.inference",
-            config_class="llama_stack.providers.inline.meta_reference.inference.MetaReferenceInferenceConfig",
+            module="llama_stack.providers.inline.inference.meta_reference",
+            config_class="llama_stack.providers.inline.inference.meta_reference.MetaReferenceInferenceConfig",
         ),
         InlineProviderSpec(
             api=Api.inference,
@@ -40,8 +40,17 @@ def available_providers() -> List[ProviderSpec]:
                     "torchao==0.5.0",
                 ]
             ),
-            module="llama_stack.providers.inline.meta_reference.inference",
-            config_class="llama_stack.providers.inline.meta_reference.inference.MetaReferenceQuantizedInferenceConfig",
+            module="llama_stack.providers.inline.inference.meta_reference",
+            config_class="llama_stack.providers.inline.inference.meta_reference.MetaReferenceQuantizedInferenceConfig",
+        ),
+        InlineProviderSpec(
+            api=Api.inference,
+            provider_type="vllm",
+            pip_packages=[
+                "vllm",
+            ],
+            module="llama_stack.providers.inline.inference.vllm",
+            config_class="llama_stack.providers.inline.inference.vllm.VLLMConfig",
         ),
         remote_provider_spec(
             api=Api.inference,
@@ -117,7 +126,7 @@ def available_providers() -> List[ProviderSpec]:
             ],
             module="llama_stack.providers.remote.inference.together",
             config_class="llama_stack.providers.remote.inference.together.TogetherImplConfig",
-            provider_data_validator="llama_stack.providers.remote.safety.together.TogetherProviderDataValidator",
+            provider_data_validator="llama_stack.providers.remote.inference.together.TogetherProviderDataValidator",
         ),
     ),
     remote_provider_spec(
@@ -140,13 +149,4 @@ def available_providers() -> List[ProviderSpec]:
            config_class="llama_stack.providers.remote.inference.databricks.DatabricksImplConfig",
         ),
     ),
-    InlineProviderSpec(
-        api=Api.inference,
-        provider_type="vllm",
-        pip_packages=[
-            "vllm",
-        ],
-        module="llama_stack.providers.inline.vllm",
-        config_class="llama_stack.providers.inline.vllm.VLLMConfig",
-    ),
 ]
@@ -36,8 +36,16 @@ def available_providers() -> List[ProviderSpec]:
             api=Api.memory,
             provider_type="meta-reference",
             pip_packages=EMBEDDING_DEPS + ["faiss-cpu"],
-            module="llama_stack.providers.inline.meta_reference.memory",
-            config_class="llama_stack.providers.inline.meta_reference.memory.FaissImplConfig",
+            module="llama_stack.providers.inline.memory.faiss",
+            config_class="llama_stack.providers.inline.memory.faiss.FaissImplConfig",
+            deprecation_warning="Please use the `faiss` provider instead.",
+        ),
+        InlineProviderSpec(
+            api=Api.memory,
+            provider_type="faiss",
+            pip_packages=EMBEDDING_DEPS + ["faiss-cpu"],
+            module="llama_stack.providers.inline.memory.faiss",
+            config_class="llama_stack.providers.inline.memory.faiss.FaissImplConfig",
         ),
         remote_provider_spec(
             Api.memory,
@@ -24,8 +24,8 @@ def available_providers() -> List[ProviderSpec]:
                 "transformers",
                 "torch --index-url https://download.pytorch.org/whl/cpu",
             ],
-            module="llama_stack.providers.inline.meta_reference.safety",
-            config_class="llama_stack.providers.inline.meta_reference.safety.SafetyConfig",
+            module="llama_stack.providers.inline.safety.meta_reference",
+            config_class="llama_stack.providers.inline.safety.meta_reference.SafetyConfig",
             api_dependencies=[
                 Api.inference,
             ],
@@ -54,8 +54,8 @@ def available_providers() -> List[ProviderSpec]:
             pip_packages=[
                 "codeshield",
             ],
-            module="llama_stack.providers.inline.meta_reference.codeshield",
-            config_class="llama_stack.providers.inline.meta_reference.codeshield.CodeShieldConfig",
+            module="llama_stack.providers.inline.safety.meta_reference",
+            config_class="llama_stack.providers.inline.safety.meta_reference.CodeShieldConfig",
             api_dependencies=[],
         ),
     ]
@@ -4,6 +4,8 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+from typing import Optional
+
 from llama_models.schema_utils import json_schema_type
 from pydantic import BaseModel, Field

@@ -14,7 +16,7 @@ class FireworksImplConfig(BaseModel):
         default="https://api.fireworks.ai/inference",
         description="The URL for the Fireworks server",
     )
-    api_key: str = Field(
-        default="",
+    api_key: Optional[str] = Field(
+        default=None,
         description="The Fireworks.ai API Key",
     )
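Since `api_key` is now optional, the Fireworks key can come either from the provider config or from the per-request provider data header that the new `_get_client` helper (below) reads. A hedged sketch of both options, with placeholder values and an assumed import path:

```python
import json

from llama_stack.providers.remote.inference.fireworks.config import (  # import path assumed
    FireworksImplConfig,
)

# Option 1 (static): put the key in the provider config.
config = FireworksImplConfig(api_key="fw-...")  # placeholder key

# Option 2 (per request): leave api_key unset and send the key with each request
# in the X-LlamaStack-ProviderData header; _get_client() retrieves it through
# get_request_provider_data().
provider_data_header = {
    "X-LlamaStack-ProviderData": json.dumps({"fireworks_api_key": "fw-..."}),
}
```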
@ -9,12 +9,11 @@ from typing import AsyncGenerator
|
||||||
from fireworks.client import Fireworks
|
from fireworks.client import Fireworks
|
||||||
|
|
||||||
from llama_models.llama3.api.chat_format import ChatFormat
|
from llama_models.llama3.api.chat_format import ChatFormat
|
||||||
|
|
||||||
from llama_models.llama3.api.datatypes import Message
|
from llama_models.llama3.api.datatypes import Message
|
||||||
from llama_models.llama3.api.tokenizer import Tokenizer
|
from llama_models.llama3.api.tokenizer import Tokenizer
|
||||||
|
|
||||||
from llama_stack.apis.inference import * # noqa: F403
|
from llama_stack.apis.inference import * # noqa: F403
|
||||||
|
from llama_stack.distribution.request_headers import NeedsRequestProviderData
|
||||||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
||||||
from llama_stack.providers.utils.inference.openai_compat import (
|
from llama_stack.providers.utils.inference.openai_compat import (
|
||||||
get_sampling_options,
|
get_sampling_options,
|
||||||
|
@ -32,7 +31,6 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||||
|
|
||||||
from .config import FireworksImplConfig
|
from .config import FireworksImplConfig
|
||||||
|
|
||||||
|
|
||||||
FIREWORKS_SUPPORTED_MODELS = {
|
FIREWORKS_SUPPORTED_MODELS = {
|
||||||
"Llama3.1-8B-Instruct": "fireworks/llama-v3p1-8b-instruct",
|
"Llama3.1-8B-Instruct": "fireworks/llama-v3p1-8b-instruct",
|
||||||
"Llama3.1-70B-Instruct": "fireworks/llama-v3p1-70b-instruct",
|
"Llama3.1-70B-Instruct": "fireworks/llama-v3p1-70b-instruct",
|
||||||
|
@ -41,10 +39,13 @@ FIREWORKS_SUPPORTED_MODELS = {
|
||||||
"Llama3.2-3B-Instruct": "fireworks/llama-v3p2-3b-instruct",
|
"Llama3.2-3B-Instruct": "fireworks/llama-v3p2-3b-instruct",
|
||||||
"Llama3.2-11B-Vision-Instruct": "fireworks/llama-v3p2-11b-vision-instruct",
|
"Llama3.2-11B-Vision-Instruct": "fireworks/llama-v3p2-11b-vision-instruct",
|
||||||
"Llama3.2-90B-Vision-Instruct": "fireworks/llama-v3p2-90b-vision-instruct",
|
"Llama3.2-90B-Vision-Instruct": "fireworks/llama-v3p2-90b-vision-instruct",
|
||||||
|
"Llama-Guard-3-8B": "fireworks/llama-guard-3-8b",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
|
class FireworksInferenceAdapter(
|
||||||
|
ModelRegistryHelper, Inference, NeedsRequestProviderData
|
||||||
|
):
|
||||||
def __init__(self, config: FireworksImplConfig) -> None:
|
def __init__(self, config: FireworksImplConfig) -> None:
|
||||||
ModelRegistryHelper.__init__(
|
ModelRegistryHelper.__init__(
|
||||||
self, stack_to_provider_models_map=FIREWORKS_SUPPORTED_MODELS
|
self, stack_to_provider_models_map=FIREWORKS_SUPPORTED_MODELS
|
||||||
|
@ -53,11 +54,24 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
|
||||||
self.formatter = ChatFormat(Tokenizer.get_instance())
|
self.formatter = ChatFormat(Tokenizer.get_instance())
|
||||||
|
|
||||||
async def initialize(self) -> None:
|
async def initialize(self) -> None:
|
||||||
return
|
pass
|
||||||
|
|
||||||
async def shutdown(self) -> None:
|
async def shutdown(self) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def _get_client(self) -> Fireworks:
|
||||||
|
fireworks_api_key = None
|
||||||
|
if self.config.api_key is not None:
|
||||||
|
fireworks_api_key = self.config.api_key
|
||||||
|
else:
|
||||||
|
provider_data = self.get_request_provider_data()
|
||||||
|
if provider_data is None or not provider_data.fireworks_api_key:
|
||||||
|
raise ValueError(
|
||||||
|
'Pass Fireworks API Key in the header X-LlamaStack-ProviderData as { "fireworks_api_key": <your api key>}'
|
||||||
|
)
|
||||||
|
fireworks_api_key = provider_data.fireworks_api_key
|
||||||
|
return Fireworks(api_key=fireworks_api_key)
|
||||||
|
|
||||||
async def completion(
|
async def completion(
|
||||||
self,
|
self,
|
||||||
model: str,
|
model: str,
|
||||||
|
@ -75,28 +89,53 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
|
||||||
stream=stream,
|
stream=stream,
|
||||||
logprobs=logprobs,
|
logprobs=logprobs,
|
||||||
)
|
)
|
||||||
client = Fireworks(api_key=self.config.api_key)
|
|
||||||
if stream:
|
if stream:
|
||||||
return self._stream_completion(request, client)
|
return self._stream_completion(request)
|
||||||
else:
|
else:
|
||||||
return await self._nonstream_completion(request, client)
|
return await self._nonstream_completion(request)
|
||||||
|
|
||||||
async def _nonstream_completion(
|
async def _nonstream_completion(
|
||||||
self, request: CompletionRequest, client: Fireworks
|
self, request: CompletionRequest
|
||||||
) -> CompletionResponse:
|
) -> CompletionResponse:
|
||||||
params = await self._get_params(request)
|
params = await self._get_params(request)
|
||||||
r = await client.completion.acreate(**params)
|
r = await self._get_client().completion.acreate(**params)
|
||||||
return process_completion_response(r, self.formatter)
|
return process_completion_response(r, self.formatter)
|
||||||
|
|
||||||
async def _stream_completion(
|
async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
|
||||||
self, request: CompletionRequest, client: Fireworks
|
|
||||||
) -> AsyncGenerator:
|
|
||||||
params = await self._get_params(request)
|
params = await self._get_params(request)
|
||||||
|
|
||||||
stream = client.completion.acreate(**params)
|
# Wrapper for async generator similar
|
||||||
|
async def _to_async_generator():
|
||||||
|
stream = self._get_client().completion.create(**params)
|
||||||
|
for chunk in stream:
|
||||||
|
yield chunk
|
||||||
|
|
||||||
|
stream = _to_async_generator()
|
||||||
async for chunk in process_completion_stream_response(stream, self.formatter):
|
async for chunk in process_completion_stream_response(stream, self.formatter):
|
||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
|
def _build_options(
|
||||||
|
self, sampling_params: Optional[SamplingParams], fmt: ResponseFormat
|
||||||
|
) -> dict:
|
||||||
|
options = get_sampling_options(sampling_params)
|
||||||
|
options.setdefault("max_tokens", 512)
|
||||||
|
|
||||||
|
if fmt:
|
||||||
|
if fmt.type == ResponseFormatType.json_schema.value:
|
||||||
|
options["response_format"] = {
|
||||||
|
"type": "json_object",
|
||||||
|
"schema": fmt.json_schema,
|
||||||
|
}
|
||||||
|
elif fmt.type == ResponseFormatType.grammar.value:
|
||||||
|
options["response_format"] = {
|
||||||
|
"type": "grammar",
|
||||||
|
"grammar": fmt.bnf,
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown response format {fmt.type}")
|
||||||
|
|
||||||
|
return options
|
||||||
|
|
||||||
async def chat_completion(
|
async def chat_completion(
|
||||||
self,
|
self,
|
||||||
model: str,
|
model: str,
|
||||||
|
@@ -121,32 +160,35 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
             logprobs=logprobs,
         )

-        client = Fireworks(api_key=self.config.api_key)
         if stream:
-            return self._stream_chat_completion(request, client)
+            return self._stream_chat_completion(request)
         else:
-            return await self._nonstream_chat_completion(request, client)
+            return await self._nonstream_chat_completion(request)

     async def _nonstream_chat_completion(
-        self, request: ChatCompletionRequest, client: Fireworks
+        self, request: ChatCompletionRequest
     ) -> ChatCompletionResponse:
         params = await self._get_params(request)
         if "messages" in params:
-            r = await client.chat.completions.acreate(**params)
+            r = await self._get_client().chat.completions.acreate(**params)
         else:
-            r = await client.completion.acreate(**params)
+            r = await self._get_client().completion.acreate(**params)
         return process_chat_completion_response(r, self.formatter)

     async def _stream_chat_completion(
-        self, request: ChatCompletionRequest, client: Fireworks
+        self, request: ChatCompletionRequest
     ) -> AsyncGenerator:
         params = await self._get_params(request)

-        if "messages" in params:
-            stream = client.chat.completions.acreate(**params)
-        else:
-            stream = client.completion.acreate(**params)
+        async def _to_async_generator():
+            if "messages" in params:
+                stream = await self._get_client().chat.completions.acreate(**params)
+            else:
+                stream = self._get_client().completion.create(**params)
+            for chunk in stream:
+                yield chunk
+
+        stream = _to_async_generator()
         async for chunk in process_chat_completion_stream_response(
             stream, self.formatter
         ):
@@ -167,41 +209,22 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
                 input_dict["prompt"] = chat_completion_request_to_prompt(
                     request, self.formatter
                 )
-        elif isinstance(request, CompletionRequest):
+        else:
             assert (
                 not media_present
             ), "Fireworks does not support media for Completion requests"
             input_dict["prompt"] = completion_request_to_prompt(request, self.formatter)
-        else:
-            raise ValueError(f"Unknown request type {type(request)}")

         # Fireworks always prepends with BOS
         if "prompt" in input_dict:
             if input_dict["prompt"].startswith("<|begin_of_text|>"):
                 input_dict["prompt"] = input_dict["prompt"][len("<|begin_of_text|>") :]

-        options = get_sampling_options(request.sampling_params)
-        options.setdefault("max_tokens", 512)
-
-        if fmt := request.response_format:
-            if fmt.type == ResponseFormatType.json_schema.value:
-                options["response_format"] = {
-                    "type": "json_object",
-                    "schema": fmt.json_schema,
-                }
-            elif fmt.type == ResponseFormatType.grammar.value:
-                options["response_format"] = {
-                    "type": "grammar",
-                    "grammar": fmt.bnf,
-                }
-            else:
-                raise ValueError(f"Unknown response format {fmt.type}")
-
         return {
             "model": self.map_to_provider_model(request.model),
             **input_dict,
             "stream": request.stream,
-            **options,
+            **self._build_options(request.sampling_params, request.response_format),
         }

     async def embeddings(
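The hunk above finishes the refactor by replacing the inline sampling and response-format handling in _get_params with a call to the new _build_options helper. As a rough sketch of the dictionary this produces, using plain dicts in place of SamplingParams and ResponseFormat (the function below and the example values are illustrative assumptions, not the adapter's real types):

from typing import Optional


def build_options_sketch(sampling_options: dict, fmt: Optional[dict]) -> dict:
    # Mirrors the helper's flow: start from sampling options, default max_tokens,
    # then attach a response_format block when one is requested.
    options = dict(sampling_options)
    options.setdefault("max_tokens", 512)
    if fmt:
        if fmt["type"] == "json_schema":
            options["response_format"] = {"type": "json_object", "schema": fmt["schema"]}
        elif fmt["type"] == "grammar":
            options["response_format"] = {"type": "grammar", "grammar": fmt["bnf"]}
        else:
            raise ValueError(f"Unknown response format {fmt['type']}")
    return options


request_params = {
    "model": "example-provider-model-id",  # placeholder, not a real Fireworks model id
    "prompt": "Hello",
    "stream": False,
    **build_options_sketch({"temperature": 0.7}, {"type": "json_schema", "schema": {"type": "object"}}),
}
print(request_params)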
@@ -4,9 +4,15 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+from pydantic import BaseModel
+
 from .config import TogetherImplConfig


+class TogetherProviderDataValidator(BaseModel):
+    together_api_key: str
+
+
 async def get_adapter_impl(config: TogetherImplConfig, _deps):
     from .together import TogetherInferenceAdapter

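The new TogetherProviderDataValidator gives the Together adapter a typed schema for per-request provider data such as the API key. A minimal sketch of how a pydantic model like this validates a payload; the example class name, the payload, and the placeholder key are assumptions for illustration only:

from pydantic import BaseModel, ValidationError


class ProviderDataExample(BaseModel):
    # Same shape as the validator added above: one required API key field.
    together_api_key: str


try:
    data = ProviderDataExample(**{"together_api_key": "sk-example"})  # placeholder key
    print(data.together_api_key)
except ValidationError as err:
    # A missing or mistyped field surfaces as a validation error up front
    # instead of a KeyError deep inside the adapter.
    print(err)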
@@ -11,7 +11,7 @@ import pytest_asyncio

 from llama_stack.distribution.datatypes import Api, Provider

-from llama_stack.providers.inline.meta_reference.agents import (
+from llama_stack.providers.inline.agents.meta_reference import (
     MetaReferenceAgentsImplConfig,
 )

@@ -52,7 +52,7 @@ class Testeval:
         response = await eval_impl.evaluate_rows(
             input_rows=rows.rows,
             scoring_functions=scoring_functions,
-            eval_task_config=AppEvalTaskConfig(
+            task_config=AppEvalTaskConfig(
                 eval_candidate=ModelCandidate(
                     model="Llama3.2-3B-Instruct",
                     sampling_params=SamplingParams(),
@@ -76,13 +76,13 @@ class Testeval:
         ]

         response = await eval_impl.run_eval(
-            eval_task_def=EvalTaskDef(
+            task=EvalTaskDef(
                 # NOTE: this is needed to make the router work for all app evals
                 identifier="meta-reference::app_eval",
                 dataset_id="test_dataset_for_eval",
                 scoring_functions=scoring_functions,
             ),
-            eval_task_config=AppEvalTaskConfig(
+            task_config=AppEvalTaskConfig(
                 eval_candidate=ModelCandidate(
                     model="Llama3.2-3B-Instruct",
                     sampling_params=SamplingParams(),
@@ -90,9 +90,13 @@ class Testeval:
             ),
         )
         assert response.job_id == "0"
-        job_status = await eval_impl.job_status(response.job_id)
+        job_status = await eval_impl.job_status(
+            response.job_id, "meta-reference::app_eval"
+        )
         assert job_status and job_status.value == "completed"
-        eval_response = await eval_impl.job_result(response.job_id)
+        eval_response = await eval_impl.job_result(
+            response.job_id, "meta-reference::app_eval"
+        )

         assert eval_response is not None
         assert len(eval_response.generations) == 5
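The test updates above reflect that job_status and job_result are now scoped by the eval task identifier in addition to the job id. A hedged sketch of one way such a lookup could be keyed, using an in-memory dict; this is an illustrative assumption, not the meta-reference provider's implementation:

from typing import Dict, Optional, Tuple

# Keyed by (eval_task_id, job_id) so the same job id can exist under different tasks.
_jobs: Dict[Tuple[str, str], str] = {}


def record_status(task_id: str, job_id: str, status: str) -> None:
    _jobs[(task_id, job_id)] = status


def job_status(job_id: str, task_id: str) -> Optional[str]:
    return _jobs.get((task_id, job_id))


record_status("meta-reference::app_eval", "0", "completed")
print(job_status("0", "meta-reference::app_eval"))  # -> completed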
@@ -10,7 +10,7 @@ import pytest
 import pytest_asyncio

 from llama_stack.distribution.datatypes import Api, Provider
-from llama_stack.providers.inline.meta_reference.inference import (
+from llama_stack.providers.inline.inference.meta_reference import (
     MetaReferenceInferenceConfig,
 )

@@ -11,7 +11,7 @@ import pytest
 import pytest_asyncio

 from llama_stack.distribution.datatypes import Api, Provider
-from llama_stack.providers.inline.meta_reference.memory import FaissImplConfig
+from llama_stack.providers.inline.memory.faiss import FaissImplConfig
 from llama_stack.providers.remote.memory.pgvector import PGVectorConfig
 from llama_stack.providers.remote.memory.weaviate import WeaviateConfig

@@ -8,7 +8,7 @@ import pytest
 import pytest_asyncio

 from llama_stack.distribution.datatypes import Api, Provider
-from llama_stack.providers.inline.meta_reference.safety import (
+from llama_stack.providers.inline.safety.meta_reference import (
     LlamaGuardShieldConfig,
     SafetyConfig,
 )
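The import fixes above and in the preceding test hunks all follow the same provider-tree reorganization: inline providers now live under llama_stack.providers.inline.<api>.<provider> rather than llama_stack.providers.inline.meta_reference.<api>. Collected in one place, and assuming llama-stack is installed so the packages resolve, the updated test imports are:

# Old layout: llama_stack.providers.inline.meta_reference.<api>
# New layout: llama_stack.providers.inline.<api>.<provider>
from llama_stack.providers.inline.agents.meta_reference import MetaReferenceAgentsImplConfig
from llama_stack.providers.inline.inference.meta_reference import MetaReferenceInferenceConfig
from llama_stack.providers.inline.memory.faiss import FaissImplConfig
from llama_stack.providers.inline.safety.meta_reference import (
    LlamaGuardShieldConfig,
    SafetyConfig,
)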
@@ -44,10 +44,10 @@ class TestScoring:
         )
         assert len(rows.rows) == 3

-        scoring_functions = [
-            "meta-reference::llm_as_judge_8b_correctness",
-            "meta-reference::equality",
-        ]
+        scoring_functions = {
+            "meta-reference::llm_as_judge_8b_correctness": None,
+            "meta-reference::equality": None,
+        }
         response = await scoring_impl.score(
             input_rows=rows.rows,
             scoring_functions=scoring_functions,
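The TestScoring change above turns scoring_functions from a list of function ids into a mapping from id to optional per-function parameters, which is what allows the later hunks to drop the separate scoring_params argument. A small sketch of how a scorer might consume that mapping; the score_row helper, its scoring logic, and the sample row are illustrative assumptions:

from typing import Any, Dict, Optional

scoring_functions: Dict[str, Optional[Dict[str, Any]]] = {
    "meta-reference::llm_as_judge_8b_correctness": {"judge_model": "Llama3.1-405B-Instruct"},
    "meta-reference::equality": None,  # None means "use the function's default parameters"
}


def score_row(row: Dict[str, Any], fn_id: str, params: Optional[Dict[str, Any]]) -> float:
    # Illustrative stand-in: a real scorer would dispatch on fn_id and apply params.
    if fn_id == "meta-reference::equality":
        return 1.0 if row["generated_answer"] == row["expected_answer"] else 0.0
    return 0.5  # placeholder score for the judge-based function


row = {"generated_answer": "4", "expected_answer": "4"}
results = {fn_id: score_row(row, fn_id, params) for fn_id, params in scoring_functions.items()}
print(results)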
@@ -83,7 +83,7 @@ class TestScoring:
         )
         assert len(rows.rows) == 3

-        params = {
+        scoring_functions = {
             "meta-reference::llm_as_judge_8b_correctness": LLMAsJudgeScoringFnParams(
                 judge_model="Llama3.1-405B-Instruct",
                 prompt_template="Output a number response in the following format: Score: <number>, where <number> is the number between 0 and 9.",
@@ -91,13 +91,9 @@ class TestScoring:
             )
         }

-        scoring_functions = [
-            "meta-reference::llm_as_judge_8b_correctness",
-        ]
         response = await scoring_impl.score(
             input_rows=rows.rows,
             scoring_functions=scoring_functions,
-            scoring_params=params,
         )
         assert len(response.results) == len(scoring_functions)
         for x in scoring_functions:
@@ -108,7 +104,6 @@ class TestScoring:
         response = await scoring_impl.score_batch(
             dataset_id="test_dataset",
             scoring_functions=scoring_functions,
-            scoring_params=params,
         )
         assert len(response.results) == len(scoring_functions)
         for x in scoring_functions: