diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index a92442dc1..79701d926 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -1,17 +1,14 @@ # What does this PR do? -Closes # (issue) +In short, provide a summary of what this PR does and why. Usually, the relevant context should be present in a linked issue. + +- [ ] Addresses issue (#issue) ## Feature/Issue validation/testing/test plan -Please describe the tests that you ran to verify your changes and relevant result summary. Provide instructions so it can be reproduced. -Please also list any relevant details for your test configuration or test plan. - -- [ ] Test A -Logs for Test A - -- [ ] Test B -Logs for Test B +Please describe: + - the tests you ran to verify your changes, with result summaries. + - instructions so the tests can be reproduced. ## Sources @@ -20,12 +17,10 @@ Please link relevant resources if necessary. ## Before submitting -- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). -- [ ] Did you read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md), - Pull Request section? -- [ ] Was this discussed/approved via a Github issue? Please add a link - to it if that's the case. -- [ ] Did you make sure to update the documentation with your changes? -- [ ] Did you write any new necessary tests? -Thanks for contributing 🎉! +- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). +- [ ] Ran pre-commit to handle lint / formatting issues. +- [ ] Read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md), + Pull Request section. +- [ ] Updated relevant documentation. +- [ ] Wrote necessary unit or integration tests. diff --git a/llama_stack/apis/eval/eval.py b/llama_stack/apis/eval/eval.py index d44ca4f0b..6aa4cae34 100644 --- a/llama_stack/apis/eval/eval.py +++ b/llama_stack/apis/eval/eval.py @@ -38,14 +38,16 @@ EvalCandidate = Annotated[ @json_schema_type class BenchmarkEvalTaskConfig(BaseModel): - eval_candidate: EvalCandidate # type: ignore + type: Literal["benchmark"] = "benchmark" + eval_candidate: EvalCandidate @json_schema_type class AppEvalTaskConfig(BaseModel): - eval_candidate: EvalCandidate # type: ignore - scoring_params: Dict[str, ScoringFnParams] = Field( # type: ignore - description="Map between scoring function id and parameters", + type: Literal["app"] = "app" + eval_candidate: EvalCandidate + scoring_params: Dict[str, ScoringFnParams] = Field( + description="Map between scoring function id and parameters for each scoring function you want to run", default_factory=dict, ) # we could optionally add any specific dataset config here @@ -64,18 +66,18 @@ class EvaluateResponse(BaseModel): class Eval(Protocol): - @webmethod(route="/eval/run_benchmark_eval", method="POST") - async def run_benchmark_eval( + @webmethod(route="/eval/run_benchmark", method="POST") + async def run_benchmark( self, benchmark_id: str, - eval_task_config: BenchmarkEvalTaskConfig, + benchmark_config: BenchmarkEvalTaskConfig, ) -> Job: ... @webmethod(route="/eval/run_eval", method="POST") async def run_eval( self, - eval_task_def: EvalTaskDef, - eval_task_config: EvalTaskConfig, + task: EvalTaskDef, + task_config: AppEvalTaskConfig, ) -> Job: ...
@webmethod(route="/eval/evaluate_rows", method="POST") @@ -83,14 +85,17 @@ class Eval(Protocol): self, input_rows: List[Dict[str, Any]], scoring_functions: List[str], - eval_task_config: EvalTaskConfig, # type: ignore + task_config: EvalTaskConfig, + eval_task_id: Optional[str] = None, ) -> EvaluateResponse: ... @webmethod(route="/eval/job/status", method="GET") - async def job_status(self, job_id: str) -> Optional[JobStatus]: ... + async def job_status( + self, job_id: str, eval_task_id: str + ) -> Optional[JobStatus]: ... @webmethod(route="/eval/job/cancel", method="POST") - async def job_cancel(self, job_id: str) -> None: ... + async def job_cancel(self, job_id: str, eval_task_id: str) -> None: ... @webmethod(route="/eval/job/result", method="GET") - async def job_result(self, job_id: str) -> EvaluateResponse: ... + async def job_result(self, job_id: str, eval_task_id: str) -> EvaluateResponse: ... diff --git a/llama_stack/apis/scoring/scoring.py b/llama_stack/apis/scoring/scoring.py index a518bd806..a68582057 100644 --- a/llama_stack/apis/scoring/scoring.py +++ b/llama_stack/apis/scoring/scoring.py @@ -48,8 +48,7 @@ class Scoring(Protocol): async def score_batch( self, dataset_id: str, - scoring_functions: List[str], - scoring_params: Optional[Dict[str, ScoringFnParams]] = None, + scoring_functions: Optional[Dict[str, ScoringFnParams]] = None, save_results_dataset: bool = False, ) -> ScoreBatchResponse: ... @@ -57,6 +56,5 @@ class Scoring(Protocol): async def score( self, input_rows: List[Dict[str, Any]], - scoring_functions: List[str], - scoring_params: Optional[Dict[str, ScoringFnParams]] = None, + scoring_functions: Optional[Dict[str, ScoringFnParams]] = None, ) -> ScoreResponse: ... diff --git a/llama_stack/apis/scoring_functions/scoring_functions.py b/llama_stack/apis/scoring_functions/scoring_functions.py index 341f84c36..140376242 100644 --- a/llama_stack/apis/scoring_functions/scoring_functions.py +++ b/llama_stack/apis/scoring_functions/scoring_functions.py @@ -33,7 +33,7 @@ class ScoringConfigType(Enum): @json_schema_type class LLMAsJudgeScoringFnParams(BaseModel): - type: Literal[ScoringConfigType.llm_as_judge.value] = ( # type: ignore + type: Literal[ScoringConfigType.llm_as_judge.value] = ( ScoringConfigType.llm_as_judge.value ) judge_model: str @@ -46,7 +46,7 @@ class LLMAsJudgeScoringFnParams(BaseModel): @json_schema_type class RegexParserScoringFnParams(BaseModel): - type: Literal[ScoringConfigType.regex_parser.value] = ( # type: ignore + type: Literal[ScoringConfigType.regex_parser.value] = ( ScoringConfigType.regex_parser.value ) parsing_regexes: Optional[List[str]] = Field( @@ -75,8 +75,8 @@ class ScoringFnDef(BaseModel): return_type: ParamType = Field( description="The return type of the deterministic function", ) - params: Optional[ScoringFnParams] = Field( # type: ignore - description="The parameters for the scoring function for benchmark eval, we could override this for app eval", + params: Optional[ScoringFnParams] = Field( + description="The parameters for the scoring function for benchmark eval, these can be overridden for app eval", default=None, ) # We can optionally add information here to support packaging of code, etc. 
diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py index 7a8d1dfee..aac7ae5b6 100644 --- a/llama_stack/distribution/resolver.py +++ b/llama_stack/distribution/resolver.py @@ -8,6 +8,8 @@ import inspect from typing import Any, Dict, List, Set +from termcolor import cprint + from llama_stack.providers.datatypes import * # noqa: F403 from llama_stack.distribution.datatypes import * # noqa: F403 @@ -100,6 +102,12 @@ async def resolve_impls( ) p = provider_registry[api][provider.provider_type] + if p.deprecation_warning: + cprint( + f"Provider `{provider.provider_type}` for API `{api}` is deprecated and will be removed in a future release: {p.deprecation_warning}", + "red", + attrs=["bold"], + ) p.deps__ = [a.value for a in p.api_dependencies] spec = ProviderWithSpec( spec=p, diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py index 18c78b06c..06d50bd65 100644 --- a/llama_stack/distribution/routers/routers.py +++ b/llama_stack/distribution/routers/routers.py @@ -216,18 +216,16 @@ class ScoringRouter(Scoring): async def score_batch( self, dataset_id: str, - scoring_functions: List[str], - scoring_params: Optional[Dict[str, ScoringFnParams]] = None, + scoring_functions: Optional[Dict[str, ScoringFnParams]] = None, save_results_dataset: bool = False, ) -> ScoreBatchResponse: res = {} - for fn_identifier in scoring_functions: + for fn_identifier in scoring_functions.keys(): score_response = await self.routing_table.get_provider_impl( fn_identifier ).score_batch( dataset_id=dataset_id, - scoring_functions=[fn_identifier], - scoring_params=scoring_params, + scoring_functions={fn_identifier: scoring_functions[fn_identifier]}, ) res.update(score_response.results) @@ -241,18 +239,16 @@ class ScoringRouter(Scoring): async def score( self, input_rows: List[Dict[str, Any]], - scoring_functions: List[str], - scoring_params: Optional[Dict[str, ScoringFnParams]] = None, + scoring_functions: Optional[Dict[str, ScoringFnParams]] = None, ) -> ScoreResponse: res = {} # look up and map each scoring function to its provider impl - for fn_identifier in scoring_functions: + for fn_identifier in scoring_functions.keys(): score_response = await self.routing_table.get_provider_impl( fn_identifier ).score( input_rows=input_rows, - scoring_functions=[fn_identifier], - scoring_params=scoring_params, + scoring_functions={fn_identifier: scoring_functions[fn_identifier]}, ) res.update(score_response.results) @@ -272,24 +268,21 @@ class EvalRouter(Eval): async def shutdown(self) -> None: pass - async def run_benchmark_eval( + async def run_benchmark( self, benchmark_id: str, - eval_task_config: BenchmarkEvalTaskConfig, + benchmark_config: BenchmarkEvalTaskConfig, ) -> Job: pass async def run_eval( self, - eval_task_def: EvalTaskDef, - eval_task_config: EvalTaskConfig, + task: EvalTaskDef, + task_config: AppEvalTaskConfig, ) -> Job: - # NOTE: We need to use DEFAULT_EVAL_TASK_IDENTIFIER to make the router work for all app evals - return await self.routing_table.get_provider_impl( - DEFAULT_EVAL_TASK_IDENTIFIER - ).run_eval( - eval_task_def=eval_task_def, - eval_task_config=eval_task_config, + return await self.routing_table.get_provider_impl(task.identifier).run_eval( + task=task, + task_config=task_config, ) @webmethod(route="/eval/evaluate_rows", method="POST") @@ -297,29 +290,42 @@ class EvalRouter(Eval): self, input_rows: List[Dict[str, Any]], scoring_functions: List[str], - eval_task_config: EvalTaskConfig, # type: ignore + task_config: 
EvalTaskConfig, + eval_task_id: Optional[str] = None, ) -> EvaluateResponse: # NOTE: This is to deal with the case where we do not pre-register an eval benchmark_task # We use default DEFAULT_EVAL_TASK_IDENTIFIER as identifier - return await self.routing_table.get_provider_impl( - DEFAULT_EVAL_TASK_IDENTIFIER - ).evaluate_rows( + if eval_task_id is None: + eval_task_id = DEFAULT_EVAL_TASK_IDENTIFIER + return await self.routing_table.get_provider_impl(eval_task_id).evaluate_rows( input_rows=input_rows, scoring_functions=scoring_functions, - eval_task_config=eval_task_config, + task_config=task_config, ) - async def job_status(self, job_id: str) -> Optional[JobStatus]: - return await self.routing_table.get_provider_impl( - DEFAULT_EVAL_TASK_IDENTIFIER - ).job_status(job_id) + async def job_status( + self, + job_id: str, + eval_task_id: str, + ) -> Optional[JobStatus]: + return await self.routing_table.get_provider_impl(eval_task_id).job_status( + job_id, eval_task_id + ) - async def job_cancel(self, job_id: str) -> None: - await self.routing_table.get_provider_impl( - DEFAULT_EVAL_TASK_IDENTIFIER - ).job_cancel(job_id) + async def job_cancel( + self, + job_id: str, + eval_task_id: str, + ) -> None: + await self.routing_table.get_provider_impl(eval_task_id).job_cancel( + job_id, eval_task_id + ) - async def job_result(self, job_id: str) -> EvaluateResponse: - return await self.routing_table.get_provider_impl( - DEFAULT_EVAL_TASK_IDENTIFIER - ).job_result(job_id) + async def job_result( + self, + job_id: str, + eval_task_id: str, + ) -> EvaluateResponse: + return await self.routing_table.get_provider_impl(eval_task_id).job_result( + job_id, eval_task_id + ) diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index 16c0fd0e0..143813780 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -31,7 +31,7 @@ from llama_stack.distribution.distribution import ( get_provider_registry, ) -from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR +from llama_stack.distribution.store.registry import create_dist_registry from llama_stack.providers.utils.telemetry.tracing import ( end_trace, @@ -42,8 +42,6 @@ from llama_stack.providers.utils.telemetry.tracing import ( from llama_stack.distribution.datatypes import * # noqa: F403 from llama_stack.distribution.request_headers import set_request_provider_data from llama_stack.distribution.resolver import resolve_impls -from llama_stack.distribution.store import CachedDiskDistributionRegistry -from llama_stack.providers.utils.kvstore import kvstore_impl, SqliteKVStoreConfig from .endpoints import get_all_api_endpoints @@ -281,21 +279,8 @@ def main( config = StackRunConfig(**yaml.safe_load(fp)) app = FastAPI() - # instantiate kvstore for storing and retrieving distribution metadata - if config.metadata_store: - dist_kvstore = asyncio.run(kvstore_impl(config.metadata_store)) - else: - dist_kvstore = asyncio.run( - kvstore_impl( - SqliteKVStoreConfig( - db_path=( - DISTRIBS_BASE_DIR / config.image_name / "kvstore.db" - ).as_posix() - ) - ) - ) - dist_registry = CachedDiskDistributionRegistry(dist_kvstore) + dist_registry, dist_kvstore = asyncio.run(create_dist_registry(config)) impls = asyncio.run(resolve_impls(config, get_provider_registry(), dist_registry)) if Api.telemetry in impls: diff --git a/llama_stack/distribution/store/registry.py b/llama_stack/distribution/store/registry.py index 994fb475c..897bb90d0 100644 --- 
a/llama_stack/distribution/store/registry.py +++ b/llama_stack/distribution/store/registry.py @@ -9,9 +9,17 @@ from typing import Dict, List, Protocol import pydantic -from llama_stack.distribution.datatypes import RoutableObjectWithProvider +from llama_stack.distribution.datatypes import ( + RoutableObjectWithProvider, + StackRunConfig, +) +from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR -from llama_stack.providers.utils.kvstore import KVStore +from llama_stack.providers.utils.kvstore import ( + KVStore, + kvstore_impl, + SqliteKVStoreConfig, +) class DistributionRegistry(Protocol): @@ -133,3 +141,21 @@ class CachedDiskDistributionRegistry(DiskDistributionRegistry): self.cache[obj.identifier].append(obj) return success + + +async def create_dist_registry( + config: StackRunConfig, +) -> tuple[CachedDiskDistributionRegistry, KVStore]: + # instantiate kvstore for storing and retrieving distribution metadata + if config.metadata_store: + dist_kvstore = await kvstore_impl(config.metadata_store) + else: + dist_kvstore = await kvstore_impl( + SqliteKVStoreConfig( + db_path=( + DISTRIBS_BASE_DIR / config.image_name / "kvstore.db" + ).as_posix() + ) + ) + + return CachedDiskDistributionRegistry(dist_kvstore), dist_kvstore diff --git a/llama_stack/providers/datatypes.py b/llama_stack/providers/datatypes.py index c4c602628..0f82ca592 100644 --- a/llama_stack/providers/datatypes.py +++ b/llama_stack/providers/datatypes.py @@ -90,6 +90,10 @@ class ProviderSpec(BaseModel): default_factory=list, description="Higher-level API surfaces may depend on other providers to provide their functionality", ) + deprecation_warning: Optional[str] = Field( + default=None, + description="If this provider is deprecated, specify the warning message here", + ) # used internally by the resolver; this is a hack for now deps__: List[str] = Field(default_factory=list) diff --git a/llama_stack/providers/inline/meta_reference/agents/__init__.py b/llama_stack/providers/inline/agents/meta_reference/__init__.py similarity index 100% rename from llama_stack/providers/inline/meta_reference/agents/__init__.py rename to llama_stack/providers/inline/agents/meta_reference/__init__.py diff --git a/llama_stack/providers/inline/meta_reference/agents/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py similarity index 100% rename from llama_stack/providers/inline/meta_reference/agents/agent_instance.py rename to llama_stack/providers/inline/agents/meta_reference/agent_instance.py diff --git a/llama_stack/providers/inline/meta_reference/agents/agents.py b/llama_stack/providers/inline/agents/meta_reference/agents.py similarity index 100% rename from llama_stack/providers/inline/meta_reference/agents/agents.py rename to llama_stack/providers/inline/agents/meta_reference/agents.py diff --git a/llama_stack/providers/inline/meta_reference/agents/config.py b/llama_stack/providers/inline/agents/meta_reference/config.py similarity index 99% rename from llama_stack/providers/inline/meta_reference/agents/config.py rename to llama_stack/providers/inline/agents/meta_reference/config.py index 2770ed13c..8ade558c3 100644 --- a/llama_stack/providers/inline/meta_reference/agents/config.py +++ b/llama_stack/providers/inline/agents/meta_reference/config.py @@ -4,10 +4,9 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
-from pydantic import BaseModel, Field - from llama_stack.providers.utils.kvstore import KVStoreConfig from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig +from pydantic import BaseModel, Field class MetaReferenceAgentsImplConfig(BaseModel): diff --git a/llama_stack/providers/inline/meta_reference/agents/persistence.py b/llama_stack/providers/inline/agents/meta_reference/persistence.py similarity index 99% rename from llama_stack/providers/inline/meta_reference/agents/persistence.py rename to llama_stack/providers/inline/agents/meta_reference/persistence.py index 37ac75d6a..36ae9b367 100644 --- a/llama_stack/providers/inline/meta_reference/agents/persistence.py +++ b/llama_stack/providers/inline/agents/meta_reference/persistence.py @@ -11,9 +11,8 @@ from datetime import datetime from typing import List, Optional from llama_stack.apis.agents import * # noqa: F403 -from pydantic import BaseModel - from llama_stack.providers.utils.kvstore import KVStore +from pydantic import BaseModel class AgentSessionInfo(BaseModel): diff --git a/llama_stack/providers/inline/meta_reference/agents/rag/__init__.py b/llama_stack/providers/inline/agents/meta_reference/rag/__init__.py similarity index 100% rename from llama_stack/providers/inline/meta_reference/agents/rag/__init__.py rename to llama_stack/providers/inline/agents/meta_reference/rag/__init__.py diff --git a/llama_stack/providers/inline/meta_reference/agents/rag/context_retriever.py b/llama_stack/providers/inline/agents/meta_reference/rag/context_retriever.py similarity index 99% rename from llama_stack/providers/inline/meta_reference/agents/rag/context_retriever.py rename to llama_stack/providers/inline/agents/meta_reference/rag/context_retriever.py index b668dc0d6..3b303f5bd 100644 --- a/llama_stack/providers/inline/meta_reference/agents/rag/context_retriever.py +++ b/llama_stack/providers/inline/agents/meta_reference/rag/context_retriever.py @@ -10,14 +10,13 @@ from jinja2 import Template from llama_models.llama3.api import * # noqa: F403 -from termcolor import cprint # noqa: F401 - from llama_stack.apis.agents import ( DefaultMemoryQueryGeneratorConfig, LLMMemoryQueryGeneratorConfig, MemoryQueryGenerator, MemoryQueryGeneratorConfig, ) +from termcolor import cprint # noqa: F401 from llama_stack.apis.inference import * # noqa: F403 diff --git a/llama_stack/providers/inline/meta_reference/agents/safety.py b/llama_stack/providers/inline/agents/meta_reference/safety.py similarity index 100% rename from llama_stack/providers/inline/meta_reference/agents/safety.py rename to llama_stack/providers/inline/agents/meta_reference/safety.py diff --git a/llama_stack/providers/inline/meta_reference/agents/tests/__init__.py b/llama_stack/providers/inline/agents/meta_reference/tests/__init__.py similarity index 100% rename from llama_stack/providers/inline/meta_reference/agents/tests/__init__.py rename to llama_stack/providers/inline/agents/meta_reference/tests/__init__.py diff --git a/llama_stack/providers/inline/meta_reference/agents/tests/code_execution.py b/llama_stack/providers/inline/agents/meta_reference/tests/code_execution.py similarity index 100% rename from llama_stack/providers/inline/meta_reference/agents/tests/code_execution.py rename to llama_stack/providers/inline/agents/meta_reference/tests/code_execution.py diff --git a/llama_stack/providers/inline/meta_reference/agents/tests/test_chat_agent.py b/llama_stack/providers/inline/agents/meta_reference/tests/test_chat_agent.py similarity index 100% rename from 
llama_stack/providers/inline/meta_reference/agents/tests/test_chat_agent.py rename to llama_stack/providers/inline/agents/meta_reference/tests/test_chat_agent.py diff --git a/llama_stack/providers/inline/meta_reference/agents/tools/__init__.py b/llama_stack/providers/inline/agents/meta_reference/tools/__init__.py similarity index 100% rename from llama_stack/providers/inline/meta_reference/agents/tools/__init__.py rename to llama_stack/providers/inline/agents/meta_reference/tools/__init__.py diff --git a/llama_stack/providers/inline/meta_reference/agents/tools/base.py b/llama_stack/providers/inline/agents/meta_reference/tools/base.py similarity index 100% rename from llama_stack/providers/inline/meta_reference/agents/tools/base.py rename to llama_stack/providers/inline/agents/meta_reference/tools/base.py diff --git a/llama_stack/providers/inline/meta_reference/agents/tools/builtin.py b/llama_stack/providers/inline/agents/meta_reference/tools/builtin.py similarity index 100% rename from llama_stack/providers/inline/meta_reference/agents/tools/builtin.py rename to llama_stack/providers/inline/agents/meta_reference/tools/builtin.py diff --git a/llama_stack/providers/inline/meta_reference/agents/tools/ipython_tool/__init__.py b/llama_stack/providers/inline/agents/meta_reference/tools/ipython_tool/__init__.py similarity index 100% rename from llama_stack/providers/inline/meta_reference/agents/tools/ipython_tool/__init__.py rename to llama_stack/providers/inline/agents/meta_reference/tools/ipython_tool/__init__.py diff --git a/llama_stack/providers/inline/meta_reference/agents/tools/ipython_tool/code_env_prefix.py b/llama_stack/providers/inline/agents/meta_reference/tools/ipython_tool/code_env_prefix.py similarity index 100% rename from llama_stack/providers/inline/meta_reference/agents/tools/ipython_tool/code_env_prefix.py rename to llama_stack/providers/inline/agents/meta_reference/tools/ipython_tool/code_env_prefix.py diff --git a/llama_stack/providers/inline/meta_reference/agents/tools/ipython_tool/code_execution.py b/llama_stack/providers/inline/agents/meta_reference/tools/ipython_tool/code_execution.py similarity index 100% rename from llama_stack/providers/inline/meta_reference/agents/tools/ipython_tool/code_execution.py rename to llama_stack/providers/inline/agents/meta_reference/tools/ipython_tool/code_execution.py diff --git a/llama_stack/providers/inline/meta_reference/agents/tools/ipython_tool/matplotlib_custom_backend.py b/llama_stack/providers/inline/agents/meta_reference/tools/ipython_tool/matplotlib_custom_backend.py similarity index 100% rename from llama_stack/providers/inline/meta_reference/agents/tools/ipython_tool/matplotlib_custom_backend.py rename to llama_stack/providers/inline/agents/meta_reference/tools/ipython_tool/matplotlib_custom_backend.py diff --git a/llama_stack/providers/inline/meta_reference/agents/tools/ipython_tool/utils.py b/llama_stack/providers/inline/agents/meta_reference/tools/ipython_tool/utils.py similarity index 100% rename from llama_stack/providers/inline/meta_reference/agents/tools/ipython_tool/utils.py rename to llama_stack/providers/inline/agents/meta_reference/tools/ipython_tool/utils.py diff --git a/llama_stack/providers/inline/meta_reference/agents/tools/safety.py b/llama_stack/providers/inline/agents/meta_reference/tools/safety.py similarity index 93% rename from llama_stack/providers/inline/meta_reference/agents/tools/safety.py rename to llama_stack/providers/inline/agents/meta_reference/tools/safety.py index 72530f0e6..1ffc99edd 100644 --- 
a/llama_stack/providers/inline/meta_reference/agents/tools/safety.py +++ b/llama_stack/providers/inline/agents/meta_reference/tools/safety.py @@ -9,8 +9,7 @@ from typing import List from llama_stack.apis.inference import Message from llama_stack.apis.safety import * # noqa: F403 -from llama_stack.providers.inline.meta_reference.agents.safety import ShieldRunnerMixin - +from ..safety import ShieldRunnerMixin from .builtin import BaseTool diff --git a/llama_stack/providers/inline/meta_reference/inference/__init__.py b/llama_stack/providers/inline/inference/meta_reference/__init__.py similarity index 100% rename from llama_stack/providers/inline/meta_reference/inference/__init__.py rename to llama_stack/providers/inline/inference/meta_reference/__init__.py diff --git a/llama_stack/providers/inline/meta_reference/inference/config.py b/llama_stack/providers/inline/inference/meta_reference/config.py similarity index 99% rename from llama_stack/providers/inline/meta_reference/inference/config.py rename to llama_stack/providers/inline/inference/meta_reference/config.py index 48cba645b..6ecba22b0 100644 --- a/llama_stack/providers/inline/meta_reference/inference/config.py +++ b/llama_stack/providers/inline/inference/meta_reference/config.py @@ -10,9 +10,8 @@ from llama_models.datatypes import * # noqa: F403 from llama_models.sku_list import resolve_model from llama_stack.apis.inference import * # noqa: F401, F403 -from pydantic import BaseModel, Field, field_validator - from llama_stack.providers.utils.inference import supported_inference_models +from pydantic import BaseModel, Field, field_validator class MetaReferenceInferenceConfig(BaseModel): diff --git a/llama_stack/providers/inline/meta_reference/inference/generation.py b/llama_stack/providers/inline/inference/meta_reference/generation.py similarity index 99% rename from llama_stack/providers/inline/meta_reference/inference/generation.py rename to llama_stack/providers/inline/inference/meta_reference/generation.py index 2f296c7c2..8d6a14fc9 100644 --- a/llama_stack/providers/inline/meta_reference/inference/generation.py +++ b/llama_stack/providers/inline/inference/meta_reference/generation.py @@ -35,13 +35,12 @@ from termcolor import cprint from llama_stack.apis.inference import * # noqa: F403 -from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerTokenizerData - from llama_stack.distribution.utils.model_utils import model_local_dir from llama_stack.providers.utils.inference.prompt_adapter import ( augment_content_with_response_format_prompt, chat_completion_request_to_messages, ) +from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerTokenizerData from .config import ( Fp8QuantizationConfig, diff --git a/llama_stack/providers/inline/meta_reference/inference/inference.py b/llama_stack/providers/inline/inference/meta_reference/inference.py similarity index 100% rename from llama_stack/providers/inline/meta_reference/inference/inference.py rename to llama_stack/providers/inline/inference/meta_reference/inference.py diff --git a/llama_stack/providers/inline/meta_reference/inference/model_parallel.py b/llama_stack/providers/inline/inference/meta_reference/model_parallel.py similarity index 100% rename from llama_stack/providers/inline/meta_reference/inference/model_parallel.py rename to llama_stack/providers/inline/inference/meta_reference/model_parallel.py diff --git a/llama_stack/providers/inline/meta_reference/inference/parallel_utils.py 
b/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py similarity index 100% rename from llama_stack/providers/inline/meta_reference/inference/parallel_utils.py rename to llama_stack/providers/inline/inference/meta_reference/parallel_utils.py index 62eeefaac..470b6b1ca 100644 --- a/llama_stack/providers/inline/meta_reference/inference/parallel_utils.py +++ b/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py @@ -28,13 +28,13 @@ from fairscale.nn.model_parallel.initialize import ( get_model_parallel_src_rank, ) +from llama_stack.apis.inference import ChatCompletionRequest, CompletionRequest + from pydantic import BaseModel, Field from torch.distributed.launcher.api import elastic_launch, LaunchConfig from typing_extensions import Annotated -from llama_stack.apis.inference import ChatCompletionRequest, CompletionRequest - from .generation import TokenResult diff --git a/llama_stack/providers/inline/meta_reference/inference/quantization/__init__.py b/llama_stack/providers/inline/inference/meta_reference/quantization/__init__.py similarity index 100% rename from llama_stack/providers/inline/meta_reference/inference/quantization/__init__.py rename to llama_stack/providers/inline/inference/meta_reference/quantization/__init__.py diff --git a/llama_stack/providers/inline/meta_reference/inference/quantization/fp8_impls.py b/llama_stack/providers/inline/inference/meta_reference/quantization/fp8_impls.py similarity index 100% rename from llama_stack/providers/inline/meta_reference/inference/quantization/fp8_impls.py rename to llama_stack/providers/inline/inference/meta_reference/quantization/fp8_impls.py diff --git a/llama_stack/providers/inline/meta_reference/inference/quantization/fp8_txest_disabled.py b/llama_stack/providers/inline/inference/meta_reference/quantization/fp8_txest_disabled.py similarity index 100% rename from llama_stack/providers/inline/meta_reference/inference/quantization/fp8_txest_disabled.py rename to llama_stack/providers/inline/inference/meta_reference/quantization/fp8_txest_disabled.py diff --git a/llama_stack/providers/inline/meta_reference/inference/quantization/hadamard_utils.py b/llama_stack/providers/inline/inference/meta_reference/quantization/hadamard_utils.py similarity index 100% rename from llama_stack/providers/inline/meta_reference/inference/quantization/hadamard_utils.py rename to llama_stack/providers/inline/inference/meta_reference/quantization/hadamard_utils.py diff --git a/llama_stack/providers/inline/meta_reference/inference/quantization/loader.py b/llama_stack/providers/inline/inference/meta_reference/quantization/loader.py similarity index 99% rename from llama_stack/providers/inline/meta_reference/inference/quantization/loader.py rename to llama_stack/providers/inline/inference/meta_reference/quantization/loader.py index 3492ab043..286224931 100644 --- a/llama_stack/providers/inline/meta_reference/inference/quantization/loader.py +++ b/llama_stack/providers/inline/inference/meta_reference/quantization/loader.py @@ -20,16 +20,15 @@ from llama_models.datatypes import CheckpointQuantizationFormat from llama_models.llama3.api.args import ModelArgs from llama_models.llama3.reference_impl.model import Transformer, TransformerBlock from llama_models.sku_list import resolve_model + +from llama_stack.apis.inference import QuantizationType + from termcolor import cprint from torch import nn, Tensor from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear -from llama_stack.apis.inference import QuantizationType - 
-from llama_stack.providers.inline.meta_reference.inference.config import ( - MetaReferenceQuantizedInferenceConfig, -) +from ..config import MetaReferenceQuantizedInferenceConfig def swiglu_wrapper( diff --git a/llama_stack/providers/inline/meta_reference/inference/quantization/scripts/__init__.py b/llama_stack/providers/inline/inference/meta_reference/quantization/scripts/__init__.py similarity index 100% rename from llama_stack/providers/inline/meta_reference/inference/quantization/scripts/__init__.py rename to llama_stack/providers/inline/inference/meta_reference/quantization/scripts/__init__.py diff --git a/llama_stack/providers/inline/meta_reference/inference/quantization/scripts/build_conda.sh b/llama_stack/providers/inline/inference/meta_reference/quantization/scripts/build_conda.sh similarity index 100% rename from llama_stack/providers/inline/meta_reference/inference/quantization/scripts/build_conda.sh rename to llama_stack/providers/inline/inference/meta_reference/quantization/scripts/build_conda.sh diff --git a/llama_stack/providers/inline/meta_reference/inference/quantization/scripts/quantize_checkpoint.py b/llama_stack/providers/inline/inference/meta_reference/quantization/scripts/quantize_checkpoint.py similarity index 100% rename from llama_stack/providers/inline/meta_reference/inference/quantization/scripts/quantize_checkpoint.py rename to llama_stack/providers/inline/inference/meta_reference/quantization/scripts/quantize_checkpoint.py diff --git a/llama_stack/providers/inline/meta_reference/inference/quantization/scripts/run_quantize_checkpoint.sh b/llama_stack/providers/inline/inference/meta_reference/quantization/scripts/run_quantize_checkpoint.sh similarity index 100% rename from llama_stack/providers/inline/meta_reference/inference/quantization/scripts/run_quantize_checkpoint.sh rename to llama_stack/providers/inline/inference/meta_reference/quantization/scripts/run_quantize_checkpoint.sh diff --git a/llama_stack/providers/inline/vllm/__init__.py b/llama_stack/providers/inline/inference/vllm/__init__.py similarity index 100% rename from llama_stack/providers/inline/vllm/__init__.py rename to llama_stack/providers/inline/inference/vllm/__init__.py diff --git a/llama_stack/providers/inline/vllm/config.py b/llama_stack/providers/inline/inference/vllm/config.py similarity index 100% rename from llama_stack/providers/inline/vllm/config.py rename to llama_stack/providers/inline/inference/vllm/config.py index a7469ebde..22b439f77 100644 --- a/llama_stack/providers/inline/vllm/config.py +++ b/llama_stack/providers/inline/inference/vllm/config.py @@ -5,9 +5,9 @@ # the root directory of this source tree. 
from llama_models.schema_utils import json_schema_type -from pydantic import BaseModel, Field, field_validator from llama_stack.providers.utils.inference import supported_inference_models +from pydantic import BaseModel, Field, field_validator @json_schema_type diff --git a/llama_stack/providers/inline/vllm/vllm.py b/llama_stack/providers/inline/inference/vllm/vllm.py similarity index 100% rename from llama_stack/providers/inline/vllm/vllm.py rename to llama_stack/providers/inline/inference/vllm/vllm.py diff --git a/llama_stack/providers/inline/meta_reference/memory/__init__.py b/llama_stack/providers/inline/memory/faiss/__init__.py similarity index 100% rename from llama_stack/providers/inline/meta_reference/memory/__init__.py rename to llama_stack/providers/inline/memory/faiss/__init__.py diff --git a/llama_stack/providers/inline/meta_reference/memory/config.py b/llama_stack/providers/inline/memory/faiss/config.py similarity index 100% rename from llama_stack/providers/inline/meta_reference/memory/config.py rename to llama_stack/providers/inline/memory/faiss/config.py index 41970b05f..fd26272ae 100644 --- a/llama_stack/providers/inline/meta_reference/memory/config.py +++ b/llama_stack/providers/inline/memory/faiss/config.py @@ -5,13 +5,13 @@ # the root directory of this source tree. from llama_models.schema_utils import json_schema_type -from pydantic import BaseModel from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR from llama_stack.providers.utils.kvstore.config import ( KVStoreConfig, SqliteKVStoreConfig, ) +from pydantic import BaseModel @json_schema_type diff --git a/llama_stack/providers/inline/meta_reference/memory/faiss.py b/llama_stack/providers/inline/memory/faiss/faiss.py similarity index 99% rename from llama_stack/providers/inline/meta_reference/memory/faiss.py rename to llama_stack/providers/inline/memory/faiss/faiss.py index 4bd5fd5a7..5726d6f87 100644 --- a/llama_stack/providers/inline/meta_reference/memory/faiss.py +++ b/llama_stack/providers/inline/memory/faiss/faiss.py @@ -8,10 +8,11 @@ import logging from typing import Any, Dict, List, Optional -import faiss import numpy as np from numpy.typing import NDArray +import faiss + from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_stack.apis.memory import * # noqa: F403 diff --git a/llama_stack/providers/inline/meta_reference/eval/eval.py b/llama_stack/providers/inline/meta_reference/eval/eval.py index 38c5869c2..a9a1978e9 100644 --- a/llama_stack/providers/inline/meta_reference/eval/eval.py +++ b/llama_stack/providers/inline/meta_reference/eval/eval.py @@ -7,11 +7,17 @@ from enum import Enum from llama_models.llama3.api.datatypes import * # noqa: F403 from .....apis.common.job_types import Job -from .....apis.eval.eval import BenchmarkEvalTaskConfig +from .....apis.eval.eval import ( + AppEvalTaskConfig, + BenchmarkEvalTaskConfig, + Eval, + EvalTaskConfig, + EvaluateResponse, + JobStatus, +) from llama_stack.apis.common.type_system import * # noqa: F403 from llama_stack.apis.datasetio import DatasetIO from llama_stack.apis.datasets import Datasets -from llama_stack.apis.eval import Eval, EvalTaskConfig, EvaluateResponse, JobStatus from llama_stack.apis.eval_tasks import EvalTaskDef from llama_stack.apis.inference import Inference from llama_stack.apis.scoring import Scoring @@ -19,6 +25,10 @@ from llama_stack.providers.datatypes import EvalTasksProtocolPrivate from .config import MetaReferenceEvalConfig + +# NOTE: this is the default eval task identifier for app eval +# it is used 
to make the router work for all app evals +# For app eval using other eval providers, the eval task identifier will be different DEFAULT_EVAL_TASK_IDENTIFIER = "meta-reference::app_eval" @@ -88,21 +98,21 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate): f"Dataset {dataset_id} does not have a correct input schema in {expected_schemas}" ) - async def run_benchmark_eval( + async def run_benchmark( self, benchmark_id: str, - eval_task_config: BenchmarkEvalTaskConfig, + benchmark_config: BenchmarkEvalTaskConfig, ) -> Job: raise NotImplementedError("Benchmark eval is not implemented yet") async def run_eval( self, - eval_task_def: EvalTaskDef, - eval_task_config: EvalTaskConfig, + task: EvalTaskDef, + task_config: AppEvalTaskConfig, ) -> Job: - dataset_id = eval_task_def.dataset_id - candidate = eval_task_config.eval_candidate - scoring_functions = eval_task_def.scoring_functions + dataset_id = task.dataset_id + candidate = task_config.eval_candidate + scoring_functions = task.scoring_functions await self.validate_eval_input_dataset_schema(dataset_id=dataset_id) all_rows = await self.datasetio_api.get_rows_paginated( @@ -112,7 +122,7 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate): res = await self.evaluate_rows( input_rows=all_rows.rows, scoring_functions=scoring_functions, - eval_task_config=eval_task_config, + task_config=task_config, ) # TODO: currently needs to wait for generation before returning @@ -125,9 +135,10 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate): self, input_rows: List[Dict[str, Any]], scoring_functions: List[str], - eval_task_config: EvalTaskConfig, + task_config: EvalTaskConfig, + eval_task_id: Optional[str] = None, ) -> EvaluateResponse: - candidate = eval_task_config.eval_candidate + candidate = task_config.eval_candidate if candidate.type == "agent": raise NotImplementedError( "Evaluation with generation has not been implemented for agents" @@ -179,23 +190,33 @@ class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate): for input_r, generated_r in zip(input_rows, generations) ] + if task_config.type == "app" and task_config.scoring_params is not None: + scoring_functions_dict = { + scoring_fn_id: task_config.scoring_params.get(scoring_fn_id, None) + for scoring_fn_id in scoring_functions + } + else: + scoring_functions_dict = { + scoring_fn_id: None for scoring_fn_id in scoring_functions + } + score_response = await self.scoring_api.score( - input_rows=score_input_rows, scoring_functions=scoring_functions + input_rows=score_input_rows, scoring_functions=scoring_functions_dict ) return EvaluateResponse(generations=generations, scores=score_response.results) - async def job_status(self, job_id: str) -> Optional[JobStatus]: + async def job_status(self, job_id: str, eval_task_id: str) -> Optional[JobStatus]: if job_id in self.jobs: return JobStatus.completed return None - async def job_cancel(self, job_id: str) -> None: + async def job_cancel(self, job_id: str, eval_task_id: str) -> None: raise NotImplementedError("Job cancel is not implemented yet") - async def job_result(self, job_id: str) -> EvaluateResponse: - status = await self.job_status(job_id) + async def job_result(self, job_id: str, eval_task_id: str) -> EvaluateResponse: + status = await self.job_status(job_id, eval_task_id) if not status or status != JobStatus.completed: raise ValueError(f"Job is not completed, Status: {status.value}") diff --git a/llama_stack/providers/inline/meta_reference/memory/tests/test_faiss.py 
b/llama_stack/providers/inline/meta_reference/memory/tests/test_faiss.py deleted file mode 100644 index 7b944319f..000000000 --- a/llama_stack/providers/inline/meta_reference/memory/tests/test_faiss.py +++ /dev/null @@ -1,73 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -import tempfile - -import pytest -from llama_stack.apis.memory import MemoryBankType, VectorMemoryBankDef -from llama_stack.providers.inline.meta_reference.memory.config import FaissImplConfig - -from llama_stack.providers.inline.meta_reference.memory.faiss import FaissMemoryImpl -from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig - - -class TestFaissMemoryImpl: - @pytest.fixture - def faiss_impl(self): - # Create a temporary SQLite database file - temp_db = tempfile.NamedTemporaryFile(suffix=".db", delete=False) - config = FaissImplConfig(kvstore=SqliteKVStoreConfig(db_path=temp_db.name)) - return FaissMemoryImpl(config) - - @pytest.mark.asyncio - async def test_initialize(self, faiss_impl): - # Test empty initialization - await faiss_impl.initialize() - assert len(faiss_impl.cache) == 0 - - # Test initialization with existing banks - bank = VectorMemoryBankDef( - identifier="test_bank", - type=MemoryBankType.vector.value, - embedding_model="all-MiniLM-L6-v2", - chunk_size_in_tokens=512, - overlap_size_in_tokens=64, - ) - - # Register a bank and reinitialize to test loading - await faiss_impl.register_memory_bank(bank) - - # Create new instance to test initialization with existing data - new_impl = FaissMemoryImpl(faiss_impl.config) - await new_impl.initialize() - - assert len(new_impl.cache) == 1 - assert "test_bank" in new_impl.cache - - @pytest.mark.asyncio - async def test_register_memory_bank(self, faiss_impl): - bank = VectorMemoryBankDef( - identifier="test_bank", - type=MemoryBankType.vector.value, - embedding_model="all-MiniLM-L6-v2", - chunk_size_in_tokens=512, - overlap_size_in_tokens=64, - ) - - await faiss_impl.initialize() - await faiss_impl.register_memory_bank(bank) - - assert "test_bank" in faiss_impl.cache - assert faiss_impl.cache["test_bank"].bank == bank - - # Verify persistence - new_impl = FaissMemoryImpl(faiss_impl.config) - await new_impl.initialize() - assert "test_bank" in new_impl.cache - - -if __name__ == "__main__": - pytest.main([__file__]) diff --git a/llama_stack/providers/inline/meta_reference/scoring/scoring.py b/llama_stack/providers/inline/meta_reference/scoring/scoring.py index 1ee617f8a..d59b8fac8 100644 --- a/llama_stack/providers/inline/meta_reference/scoring/scoring.py +++ b/llama_stack/providers/inline/meta_reference/scoring/scoring.py @@ -89,8 +89,7 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate): async def score_batch( self, dataset_id: str, - scoring_functions: List[str], - scoring_params: Optional[Dict[str, ScoringFnParams]] = None, + scoring_functions: Optional[Dict[str, ScoringFnParams]] = None, save_results_dataset: bool = False, ) -> ScoreBatchResponse: await self.validate_scoring_input_dataset_schema(dataset_id=dataset_id) @@ -101,7 +100,6 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate): res = await self.score( input_rows=all_rows.rows, scoring_functions=scoring_functions, - scoring_params=scoring_params, ) if save_results_dataset: # TODO: persist and register dataset on to server for reading @@ -115,17 +113,14 @@ class 
MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate): async def score( self, input_rows: List[Dict[str, Any]], - scoring_functions: List[str], - scoring_params: Optional[Dict[str, ScoringFnParams]] = None, + scoring_functions: Optional[Dict[str, ScoringFnParams]] = None, ) -> ScoreResponse: res = {} - for scoring_fn_id in scoring_functions: + for scoring_fn_id in scoring_functions.keys(): if scoring_fn_id not in self.scoring_fn_id_impls: raise ValueError(f"Scoring function {scoring_fn_id} is not supported.") scoring_fn = self.scoring_fn_id_impls[scoring_fn_id] - scoring_fn_params = None - if scoring_params is not None: - scoring_fn_params = scoring_params.get(scoring_fn_id, None) + scoring_fn_params = scoring_functions.get(scoring_fn_id, None) score_results = await scoring_fn.score( input_rows, scoring_fn_id, scoring_fn_params ) diff --git a/llama_stack/providers/inline/meta_reference/safety/__init__.py b/llama_stack/providers/inline/safety/meta_reference/__init__.py similarity index 100% rename from llama_stack/providers/inline/meta_reference/safety/__init__.py rename to llama_stack/providers/inline/safety/meta_reference/__init__.py diff --git a/llama_stack/providers/inline/meta_reference/safety/base.py b/llama_stack/providers/inline/safety/meta_reference/base.py similarity index 100% rename from llama_stack/providers/inline/meta_reference/safety/base.py rename to llama_stack/providers/inline/safety/meta_reference/base.py diff --git a/llama_stack/providers/inline/meta_reference/safety/config.py b/llama_stack/providers/inline/safety/meta_reference/config.py similarity index 100% rename from llama_stack/providers/inline/meta_reference/safety/config.py rename to llama_stack/providers/inline/safety/meta_reference/config.py diff --git a/llama_stack/providers/inline/meta_reference/safety/llama_guard.py b/llama_stack/providers/inline/safety/meta_reference/llama_guard.py similarity index 100% rename from llama_stack/providers/inline/meta_reference/safety/llama_guard.py rename to llama_stack/providers/inline/safety/meta_reference/llama_guard.py diff --git a/llama_stack/providers/inline/meta_reference/safety/prompt_guard.py b/llama_stack/providers/inline/safety/meta_reference/prompt_guard.py similarity index 100% rename from llama_stack/providers/inline/meta_reference/safety/prompt_guard.py rename to llama_stack/providers/inline/safety/meta_reference/prompt_guard.py diff --git a/llama_stack/providers/inline/meta_reference/safety/safety.py b/llama_stack/providers/inline/safety/meta_reference/safety.py similarity index 100% rename from llama_stack/providers/inline/meta_reference/safety/safety.py rename to llama_stack/providers/inline/safety/meta_reference/safety.py diff --git a/llama_stack/providers/registry/agents.py b/llama_stack/providers/registry/agents.py index 774dde858..989b9f077 100644 --- a/llama_stack/providers/registry/agents.py +++ b/llama_stack/providers/registry/agents.py @@ -22,8 +22,8 @@ def available_providers() -> List[ProviderSpec]: "scikit-learn", ] + kvstore_dependencies(), - module="llama_stack.providers.inline.meta_reference.agents", - config_class="llama_stack.providers.inline.meta_reference.agents.MetaReferenceAgentsImplConfig", + module="llama_stack.providers.inline.agents.meta_reference", + config_class="llama_stack.providers.inline.agents.meta_reference.MetaReferenceAgentsImplConfig", api_dependencies=[ Api.inference, Api.safety, diff --git a/llama_stack/providers/registry/inference.py b/llama_stack/providers/registry/inference.py index 8a3619118..dc6fa9592 
100644 --- a/llama_stack/providers/registry/inference.py +++ b/llama_stack/providers/registry/inference.py @@ -27,8 +27,8 @@ def available_providers() -> List[ProviderSpec]: api=Api.inference, provider_type="meta-reference", pip_packages=META_REFERENCE_DEPS, - module="llama_stack.providers.inline.meta_reference.inference", - config_class="llama_stack.providers.inline.meta_reference.inference.MetaReferenceInferenceConfig", + module="llama_stack.providers.inline.inference.meta_reference", + config_class="llama_stack.providers.inline.inference.meta_reference.MetaReferenceInferenceConfig", ), InlineProviderSpec( api=Api.inference, @@ -40,8 +40,17 @@ def available_providers() -> List[ProviderSpec]: "torchao==0.5.0", ] ), - module="llama_stack.providers.inline.meta_reference.inference", - config_class="llama_stack.providers.inline.meta_reference.inference.MetaReferenceQuantizedInferenceConfig", + module="llama_stack.providers.inline.inference.meta_reference", + config_class="llama_stack.providers.inline.inference.meta_reference.MetaReferenceQuantizedInferenceConfig", + ), + InlineProviderSpec( + api=Api.inference, + provider_type="vllm", + pip_packages=[ + "vllm", + ], + module="llama_stack.providers.inline.inference.vllm", + config_class="llama_stack.providers.inline.inference.vllm.VLLMConfig", ), remote_provider_spec( api=Api.inference, @@ -117,7 +126,7 @@ def available_providers() -> List[ProviderSpec]: ], module="llama_stack.providers.remote.inference.together", config_class="llama_stack.providers.remote.inference.together.TogetherImplConfig", - provider_data_validator="llama_stack.providers.remote.safety.together.TogetherProviderDataValidator", + provider_data_validator="llama_stack.providers.remote.inference.together.TogetherProviderDataValidator", ), ), remote_provider_spec( @@ -140,13 +149,4 @@ def available_providers() -> List[ProviderSpec]: config_class="llama_stack.providers.remote.inference.databricks.DatabricksImplConfig", ), ), - InlineProviderSpec( - api=Api.inference, - provider_type="vllm", - pip_packages=[ - "vllm", - ], - module="llama_stack.providers.inline.vllm", - config_class="llama_stack.providers.inline.vllm.VLLMConfig", - ), ] diff --git a/llama_stack/providers/registry/memory.py b/llama_stack/providers/registry/memory.py index c2740017a..93ecb7c13 100644 --- a/llama_stack/providers/registry/memory.py +++ b/llama_stack/providers/registry/memory.py @@ -36,8 +36,16 @@ def available_providers() -> List[ProviderSpec]: api=Api.memory, provider_type="meta-reference", pip_packages=EMBEDDING_DEPS + ["faiss-cpu"], - module="llama_stack.providers.inline.meta_reference.memory", - config_class="llama_stack.providers.inline.meta_reference.memory.FaissImplConfig", + module="llama_stack.providers.inline.memory.faiss", + config_class="llama_stack.providers.inline.memory.faiss.FaissImplConfig", + deprecation_warning="Please use the `faiss` provider instead.", + ), + InlineProviderSpec( + api=Api.memory, + provider_type="faiss", + pip_packages=EMBEDDING_DEPS + ["faiss-cpu"], + module="llama_stack.providers.inline.memory.faiss", + config_class="llama_stack.providers.inline.memory.faiss.FaissImplConfig", ), remote_provider_spec( Api.memory, diff --git a/llama_stack/providers/registry/safety.py b/llama_stack/providers/registry/safety.py index fdaa33192..fb5b6695a 100644 --- a/llama_stack/providers/registry/safety.py +++ b/llama_stack/providers/registry/safety.py @@ -24,8 +24,8 @@ def available_providers() -> List[ProviderSpec]: "transformers", "torch --index-url 
https://download.pytorch.org/whl/cpu", ], - module="llama_stack.providers.inline.meta_reference.safety", - config_class="llama_stack.providers.inline.meta_reference.safety.SafetyConfig", + module="llama_stack.providers.inline.safety.meta_reference", + config_class="llama_stack.providers.inline.safety.meta_reference.SafetyConfig", api_dependencies=[ Api.inference, ], @@ -54,8 +54,8 @@ def available_providers() -> List[ProviderSpec]: pip_packages=[ "codeshield", ], - module="llama_stack.providers.inline.meta_reference.codeshield", - config_class="llama_stack.providers.inline.meta_reference.codeshield.CodeShieldConfig", + module="llama_stack.providers.inline.safety.meta_reference", + config_class="llama_stack.providers.inline.safety.meta_reference.CodeShieldConfig", api_dependencies=[], ), ] diff --git a/llama_stack/providers/remote/inference/fireworks/config.py b/llama_stack/providers/remote/inference/fireworks/config.py index 827bc620f..275ce99e7 100644 --- a/llama_stack/providers/remote/inference/fireworks/config.py +++ b/llama_stack/providers/remote/inference/fireworks/config.py @@ -4,6 +4,8 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +from typing import Optional + from llama_models.schema_utils import json_schema_type from pydantic import BaseModel, Field @@ -14,7 +16,7 @@ class FireworksImplConfig(BaseModel): default="https://api.fireworks.ai/inference", description="The URL for the Fireworks server", ) - api_key: str = Field( - default="", + api_key: Optional[str] = Field( + default=None, description="The Fireworks.ai API Key", ) diff --git a/llama_stack/providers/remote/inference/fireworks/fireworks.py b/llama_stack/providers/remote/inference/fireworks/fireworks.py index 0070756d8..57e851c5b 100644 --- a/llama_stack/providers/remote/inference/fireworks/fireworks.py +++ b/llama_stack/providers/remote/inference/fireworks/fireworks.py @@ -9,12 +9,11 @@ from typing import AsyncGenerator from fireworks.client import Fireworks from llama_models.llama3.api.chat_format import ChatFormat - from llama_models.llama3.api.datatypes import Message from llama_models.llama3.api.tokenizer import Tokenizer from llama_stack.apis.inference import * # noqa: F403 - +from llama_stack.distribution.request_headers import NeedsRequestProviderData from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper from llama_stack.providers.utils.inference.openai_compat import ( get_sampling_options, @@ -32,7 +31,6 @@ from llama_stack.providers.utils.inference.prompt_adapter import ( from .config import FireworksImplConfig - FIREWORKS_SUPPORTED_MODELS = { "Llama3.1-8B-Instruct": "fireworks/llama-v3p1-8b-instruct", "Llama3.1-70B-Instruct": "fireworks/llama-v3p1-70b-instruct", @@ -41,10 +39,13 @@ FIREWORKS_SUPPORTED_MODELS = { "Llama3.2-3B-Instruct": "fireworks/llama-v3p2-3b-instruct", "Llama3.2-11B-Vision-Instruct": "fireworks/llama-v3p2-11b-vision-instruct", "Llama3.2-90B-Vision-Instruct": "fireworks/llama-v3p2-90b-vision-instruct", + "Llama-Guard-3-8B": "fireworks/llama-guard-3-8b", } -class FireworksInferenceAdapter(ModelRegistryHelper, Inference): +class FireworksInferenceAdapter( + ModelRegistryHelper, Inference, NeedsRequestProviderData +): def __init__(self, config: FireworksImplConfig) -> None: ModelRegistryHelper.__init__( self, stack_to_provider_models_map=FIREWORKS_SUPPORTED_MODELS @@ -53,11 +54,24 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference): self.formatter = 
ChatFormat(Tokenizer.get_instance()) async def initialize(self) -> None: - return + pass async def shutdown(self) -> None: pass + def _get_client(self) -> Fireworks: + fireworks_api_key = None + if self.config.api_key is not None: + fireworks_api_key = self.config.api_key + else: + provider_data = self.get_request_provider_data() + if provider_data is None or not provider_data.fireworks_api_key: + raise ValueError( + 'Pass Fireworks API Key in the header X-LlamaStack-ProviderData as { "fireworks_api_key": <your api key> }' + ) + fireworks_api_key = provider_data.fireworks_api_key + return Fireworks(api_key=fireworks_api_key) + async def completion( self, model: str, @@ -75,28 +89,53 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference): stream=stream, logprobs=logprobs, ) - client = Fireworks(api_key=self.config.api_key) if stream: - return self._stream_completion(request, client) + return self._stream_completion(request) else: - return await self._nonstream_completion(request, client) + return await self._nonstream_completion(request) async def _nonstream_completion( - self, request: CompletionRequest, client: Fireworks + self, request: CompletionRequest ) -> CompletionResponse: params = await self._get_params(request) - r = await client.completion.acreate(**params) + r = await self._get_client().completion.acreate(**params) return process_completion_response(r, self.formatter) - async def _stream_completion( - self, request: CompletionRequest, client: Fireworks - ) -> AsyncGenerator: + async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator: params = await self._get_params(request) - stream = client.completion.acreate(**params) + # Wrapper to turn the sync stream into an async generator + async def _to_async_generator(): + stream = self._get_client().completion.create(**params) + for chunk in stream: + yield chunk + + stream = _to_async_generator() async for chunk in process_completion_stream_response(stream, self.formatter): yield chunk + def _build_options( + self, sampling_params: Optional[SamplingParams], fmt: ResponseFormat + ) -> dict: + options = get_sampling_options(sampling_params) + options.setdefault("max_tokens", 512) + + if fmt: + if fmt.type == ResponseFormatType.json_schema.value: + options["response_format"] = { + "type": "json_object", + "schema": fmt.json_schema, + } + elif fmt.type == ResponseFormatType.grammar.value: + options["response_format"] = { + "type": "grammar", + "grammar": fmt.bnf, + } + else: + raise ValueError(f"Unknown response format {fmt.type}") + + return options + async def chat_completion( self, model: str, @@ -121,32 +160,35 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference): logprobs=logprobs, ) - client = Fireworks(api_key=self.config.api_key) if stream: - return self._stream_chat_completion(request, client) + return self._stream_chat_completion(request) else: - return await self._nonstream_chat_completion(request, client) + return await self._nonstream_chat_completion(request) async def _nonstream_chat_completion( - self, request: ChatCompletionRequest, client: Fireworks + self, request: ChatCompletionRequest ) -> ChatCompletionResponse: params = await self._get_params(request) if "messages" in params: - r = await client.chat.completions.acreate(**params) + r = await self._get_client().chat.completions.acreate(**params) else: - r = await client.completion.acreate(**params) + r = await self._get_client().completion.acreate(**params) return process_chat_completion_response(r, self.formatter) async def _stream_chat_completion( -
     async def _stream_chat_completion(
-        self, request: ChatCompletionRequest, client: Fireworks
+        self, request: ChatCompletionRequest
     ) -> AsyncGenerator:
         params = await self._get_params(request)
 
-        if "messages" in params:
-            stream = client.chat.completions.acreate(**params)
-        else:
-            stream = client.completion.acreate(**params)
+        async def _to_async_generator():
+            if "messages" in params:
+                stream = await self._get_client().chat.completions.acreate(**params)
+            else:
+                stream = self._get_client().completion.create(**params)
+            for chunk in stream:
+                yield chunk
 
+        stream = _to_async_generator()
         async for chunk in process_chat_completion_stream_response(
             stream, self.formatter
         ):
@@ -167,41 +209,22 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
             input_dict["prompt"] = chat_completion_request_to_prompt(
                 request, self.formatter
             )
-        elif isinstance(request, CompletionRequest):
+        else:
             assert (
                 not media_present
             ), "Fireworks does not support media for Completion requests"
             input_dict["prompt"] = completion_request_to_prompt(request, self.formatter)
-        else:
-            raise ValueError(f"Unknown request type {type(request)}")
 
         # Fireworks always prepends with BOS
         if "prompt" in input_dict:
             if input_dict["prompt"].startswith("<|begin_of_text|>"):
                 input_dict["prompt"] = input_dict["prompt"][len("<|begin_of_text|>") :]
 
-        options = get_sampling_options(request.sampling_params)
-        options.setdefault("max_tokens", 512)
-
-        if fmt := request.response_format:
-            if fmt.type == ResponseFormatType.json_schema.value:
-                options["response_format"] = {
-                    "type": "json_object",
-                    "schema": fmt.json_schema,
-                }
-            elif fmt.type == ResponseFormatType.grammar.value:
-                options["response_format"] = {
-                    "type": "grammar",
-                    "grammar": fmt.bnf,
-                }
-            else:
-                raise ValueError(f"Unknown response format {fmt.type}")
-
         return {
             "model": self.map_to_provider_model(request.model),
             **input_dict,
             "stream": request.stream,
-            **options,
+            **self._build_options(request.sampling_params, request.response_format),
         }
 
     async def embeddings(
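Since the comment in `_get_params` notes that Fireworks always prepends BOS server-side, the patch keeps the guard that strips one leading `<|begin_of_text|>` from the locally rendered prompt so the token is not doubled. The same check in isolation:

# Illustrative sketch, not part of the patch: the BOS-stripping guard alone.
BOS = "<|begin_of_text|>"


def strip_leading_bos(prompt: str) -> str:
    # Drop one leading BOS because the Fireworks endpoint adds it again.
    if prompt.startswith(BOS):
        return prompt[len(BOS):]
    return prompt


assert strip_leading_bos(BOS + "Hello") == "Hello"
assert strip_leading_bos("Hello") == "Hello"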
diff --git a/llama_stack/providers/remote/inference/together/__init__.py b/llama_stack/providers/remote/inference/together/__init__.py
index 05ea91e58..2bbd9ed53 100644
--- a/llama_stack/providers/remote/inference/together/__init__.py
+++ b/llama_stack/providers/remote/inference/together/__init__.py
@@ -4,9 +4,15 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+from pydantic import BaseModel
+
 from .config import TogetherImplConfig
 
 
+class TogetherProviderDataValidator(BaseModel):
+    together_api_key: str
+
+
 async def get_adapter_impl(config: TogetherImplConfig, _deps):
     from .together import TogetherInferenceAdapter
 
diff --git a/llama_stack/providers/tests/agents/fixtures.py b/llama_stack/providers/tests/agents/fixtures.py
index 86ecae1e9..8330e2604 100644
--- a/llama_stack/providers/tests/agents/fixtures.py
+++ b/llama_stack/providers/tests/agents/fixtures.py
@@ -11,7 +11,7 @@ import pytest_asyncio
 
 from llama_stack.distribution.datatypes import Api, Provider
 
-from llama_stack.providers.inline.meta_reference.agents import (
+from llama_stack.providers.inline.agents.meta_reference import (
     MetaReferenceAgentsImplConfig,
 )
 
diff --git a/llama_stack/providers/tests/eval/test_eval.py b/llama_stack/providers/tests/eval/test_eval.py
index 80a715bd2..1a41bb44a 100644
--- a/llama_stack/providers/tests/eval/test_eval.py
+++ b/llama_stack/providers/tests/eval/test_eval.py
@@ -52,7 +52,7 @@ class Testeval:
         response = await eval_impl.evaluate_rows(
             input_rows=rows.rows,
             scoring_functions=scoring_functions,
-            eval_task_config=AppEvalTaskConfig(
+            task_config=AppEvalTaskConfig(
                 eval_candidate=ModelCandidate(
                     model="Llama3.2-3B-Instruct",
                     sampling_params=SamplingParams(),
@@ -76,13 +76,13 @@ class Testeval:
         ]
 
         response = await eval_impl.run_eval(
-            eval_task_def=EvalTaskDef(
+            task=EvalTaskDef(
                 # NOTE: this is needed to make the router work for all app evals
                 identifier="meta-reference::app_eval",
                 dataset_id="test_dataset_for_eval",
                 scoring_functions=scoring_functions,
             ),
-            eval_task_config=AppEvalTaskConfig(
+            task_config=AppEvalTaskConfig(
                 eval_candidate=ModelCandidate(
                     model="Llama3.2-3B-Instruct",
                     sampling_params=SamplingParams(),
@@ -90,9 +90,13 @@ class Testeval:
             ),
         )
         assert response.job_id == "0"
-        job_status = await eval_impl.job_status(response.job_id)
+        job_status = await eval_impl.job_status(
+            response.job_id, "meta-reference::app_eval"
+        )
         assert job_status and job_status.value == "completed"
-        eval_response = await eval_impl.job_result(response.job_id)
+        eval_response = await eval_impl.job_result(
+            response.job_id, "meta-reference::app_eval"
+        )
         assert eval_response is not None
         assert len(eval_response.generations) == 5
 
diff --git a/llama_stack/providers/tests/inference/fixtures.py b/llama_stack/providers/tests/inference/fixtures.py
index 41b9eb3cf..1698d7584 100644
--- a/llama_stack/providers/tests/inference/fixtures.py
+++ b/llama_stack/providers/tests/inference/fixtures.py
@@ -10,7 +10,7 @@ import pytest
 import pytest_asyncio
 
 from llama_stack.distribution.datatypes import Api, Provider
-from llama_stack.providers.inline.meta_reference.inference import (
+from llama_stack.providers.inline.inference.meta_reference import (
     MetaReferenceInferenceConfig,
 )
 
diff --git a/llama_stack/providers/tests/memory/fixtures.py b/llama_stack/providers/tests/memory/fixtures.py
index b30e0fae4..c0931b009 100644
--- a/llama_stack/providers/tests/memory/fixtures.py
+++ b/llama_stack/providers/tests/memory/fixtures.py
@@ -11,7 +11,7 @@ import pytest
 import pytest_asyncio
 
 from llama_stack.distribution.datatypes import Api, Provider
-from llama_stack.providers.inline.meta_reference.memory import FaissImplConfig
+from llama_stack.providers.inline.memory.faiss import FaissImplConfig
 from llama_stack.providers.remote.memory.pgvector import PGVectorConfig
 from llama_stack.providers.remote.memory.weaviate import WeaviateConfig
 
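The `TogetherProviderDataValidator` above mirrors the `fireworks_api_key` mechanism from earlier in this patch, which suggests a Together key can likewise arrive per request via the `X-LlamaStack-ProviderData` header. A hedged client-side sketch; the host, port, and route are illustrative assumptions, and only the header name and JSON key come from this patch:

# Illustrative sketch, not part of the patch: passing a provider key per request.
import json

import httpx

resp = httpx.post(
    "http://localhost:5000/inference/chat_completion",  # assumed endpoint
    headers={
        # Header name taken from the Fireworks error message above; the JSON
        # key matches TogetherProviderDataValidator.together_api_key.
        "X-LlamaStack-ProviderData": json.dumps({"together_api_key": "..."}),
    },
    json={
        "model": "Llama3.1-8B-Instruct",
        "messages": [{"role": "user", "content": "Hello"}],
    },
)
print(resp.status_code)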
diff --git a/llama_stack/providers/tests/safety/fixtures.py b/llama_stack/providers/tests/safety/fixtures.py
index de1829355..58859c991 100644
--- a/llama_stack/providers/tests/safety/fixtures.py
+++ b/llama_stack/providers/tests/safety/fixtures.py
@@ -8,7 +8,7 @@ import pytest
 import pytest_asyncio
 
 from llama_stack.distribution.datatypes import Api, Provider
-from llama_stack.providers.inline.meta_reference.safety import (
+from llama_stack.providers.inline.safety.meta_reference import (
     LlamaGuardShieldConfig,
     SafetyConfig,
 )
 
diff --git a/llama_stack/providers/tests/scoring/test_scoring.py b/llama_stack/providers/tests/scoring/test_scoring.py
index 7a36a96f9..3c1b6554f 100644
--- a/llama_stack/providers/tests/scoring/test_scoring.py
+++ b/llama_stack/providers/tests/scoring/test_scoring.py
@@ -44,10 +44,10 @@ class TestScoring:
         )
         assert len(rows.rows) == 3
 
-        scoring_functions = [
-            "meta-reference::llm_as_judge_8b_correctness",
-            "meta-reference::equality",
-        ]
+        scoring_functions = {
+            "meta-reference::llm_as_judge_8b_correctness": None,
+            "meta-reference::equality": None,
+        }
         response = await scoring_impl.score(
             input_rows=rows.rows,
             scoring_functions=scoring_functions,
@@ -83,7 +83,7 @@ class TestScoring:
         )
         assert len(rows.rows) == 3
 
-        params = {
+        scoring_functions = {
             "meta-reference::llm_as_judge_8b_correctness": LLMAsJudgeScoringFnParams(
                 judge_model="Llama3.1-405B-Instruct",
                 prompt_template="Output a number response in the following format: Score: <number>, where <number> is the number between 0 and 9.",
             )
         }
 
-        scoring_functions = [
-            "meta-reference::llm_as_judge_8b_correctness",
-        ]
         response = await scoring_impl.score(
             input_rows=rows.rows,
             scoring_functions=scoring_functions,
-            scoring_params=params,
         )
         assert len(response.results) == len(scoring_functions)
         for x in scoring_functions:
@@ -108,7 +104,6 @@ class TestScoring:
         response = await scoring_impl.score_batch(
             dataset_id="test_dataset",
             scoring_functions=scoring_functions,
-            scoring_params=params,
         )
         assert len(response.results) == len(scoring_functions)
         for x in scoring_functions:
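A closing note on the scoring-test migration: the assertions survive the change from a list of function ids to a dict keyed by function id because `len()` counts dict entries and iterating a dict yields its keys. A quick stand-alone check:

# Illustrative sketch, not part of the patch: why the assertions still hold.
scoring_functions = {
    "meta-reference::llm_as_judge_8b_correctness": None,
    "meta-reference::equality": None,
}

# Stand-in for response.results, keyed by scoring function id.
results = {fn_id: object() for fn_id in scoring_functions}

assert len(results) == len(scoring_functions)  # len() counts dict entries
for x in scoring_functions:  # iteration yields the function ids
    assert x in results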