From 31c5fbda5e1484e5d3b241d8b60ae689530c9a7c Mon Sep 17 00:00:00 2001 From: "Yufei (Benny) Chen" <1585539+benjibc@users.noreply.github.com> Date: Thu, 7 Nov 2024 10:11:28 -0800 Subject: [PATCH 1/6] [LlamaStack][Fireworks] Update client and add unittest (#390) --- .../remote/inference/fireworks/config.py | 6 +- .../remote/inference/fireworks/fireworks.py | 115 +++++++++++------- 2 files changed, 73 insertions(+), 48 deletions(-) diff --git a/llama_stack/providers/remote/inference/fireworks/config.py b/llama_stack/providers/remote/inference/fireworks/config.py index 827bc620f..275ce99e7 100644 --- a/llama_stack/providers/remote/inference/fireworks/config.py +++ b/llama_stack/providers/remote/inference/fireworks/config.py @@ -4,6 +4,8 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +from typing import Optional + from llama_models.schema_utils import json_schema_type from pydantic import BaseModel, Field @@ -14,7 +16,7 @@ class FireworksImplConfig(BaseModel): default="https://api.fireworks.ai/inference", description="The URL for the Fireworks server", ) - api_key: str = Field( - default="", + api_key: Optional[str] = Field( + default=None, description="The Fireworks.ai API Key", ) diff --git a/llama_stack/providers/remote/inference/fireworks/fireworks.py b/llama_stack/providers/remote/inference/fireworks/fireworks.py index 0070756d8..57e851c5b 100644 --- a/llama_stack/providers/remote/inference/fireworks/fireworks.py +++ b/llama_stack/providers/remote/inference/fireworks/fireworks.py @@ -9,12 +9,11 @@ from typing import AsyncGenerator from fireworks.client import Fireworks from llama_models.llama3.api.chat_format import ChatFormat - from llama_models.llama3.api.datatypes import Message from llama_models.llama3.api.tokenizer import Tokenizer from llama_stack.apis.inference import * # noqa: F403 - +from llama_stack.distribution.request_headers import NeedsRequestProviderData from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper from llama_stack.providers.utils.inference.openai_compat import ( get_sampling_options, @@ -32,7 +31,6 @@ from llama_stack.providers.utils.inference.prompt_adapter import ( from .config import FireworksImplConfig - FIREWORKS_SUPPORTED_MODELS = { "Llama3.1-8B-Instruct": "fireworks/llama-v3p1-8b-instruct", "Llama3.1-70B-Instruct": "fireworks/llama-v3p1-70b-instruct", @@ -41,10 +39,13 @@ FIREWORKS_SUPPORTED_MODELS = { "Llama3.2-3B-Instruct": "fireworks/llama-v3p2-3b-instruct", "Llama3.2-11B-Vision-Instruct": "fireworks/llama-v3p2-11b-vision-instruct", "Llama3.2-90B-Vision-Instruct": "fireworks/llama-v3p2-90b-vision-instruct", + "Llama-Guard-3-8B": "fireworks/llama-guard-3-8b", } -class FireworksInferenceAdapter(ModelRegistryHelper, Inference): +class FireworksInferenceAdapter( + ModelRegistryHelper, Inference, NeedsRequestProviderData +): def __init__(self, config: FireworksImplConfig) -> None: ModelRegistryHelper.__init__( self, stack_to_provider_models_map=FIREWORKS_SUPPORTED_MODELS @@ -53,11 +54,24 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference): self.formatter = ChatFormat(Tokenizer.get_instance()) async def initialize(self) -> None: - return + pass async def shutdown(self) -> None: pass + def _get_client(self) -> Fireworks: + fireworks_api_key = None + if self.config.api_key is not None: + fireworks_api_key = self.config.api_key + else: + provider_data = self.get_request_provider_data() + if provider_data is None or not 
provider_data.fireworks_api_key: + raise ValueError( + 'Pass Fireworks API Key in the header X-LlamaStack-ProviderData as { "fireworks_api_key": }' + ) + fireworks_api_key = provider_data.fireworks_api_key + return Fireworks(api_key=fireworks_api_key) + async def completion( self, model: str, @@ -75,28 +89,53 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference): stream=stream, logprobs=logprobs, ) - client = Fireworks(api_key=self.config.api_key) if stream: - return self._stream_completion(request, client) + return self._stream_completion(request) else: - return await self._nonstream_completion(request, client) + return await self._nonstream_completion(request) async def _nonstream_completion( - self, request: CompletionRequest, client: Fireworks + self, request: CompletionRequest ) -> CompletionResponse: params = await self._get_params(request) - r = await client.completion.acreate(**params) + r = await self._get_client().completion.acreate(**params) return process_completion_response(r, self.formatter) - async def _stream_completion( - self, request: CompletionRequest, client: Fireworks - ) -> AsyncGenerator: + async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator: params = await self._get_params(request) - stream = client.completion.acreate(**params) + # Wrapper for async generator similar + async def _to_async_generator(): + stream = self._get_client().completion.create(**params) + for chunk in stream: + yield chunk + + stream = _to_async_generator() async for chunk in process_completion_stream_response(stream, self.formatter): yield chunk + def _build_options( + self, sampling_params: Optional[SamplingParams], fmt: ResponseFormat + ) -> dict: + options = get_sampling_options(sampling_params) + options.setdefault("max_tokens", 512) + + if fmt: + if fmt.type == ResponseFormatType.json_schema.value: + options["response_format"] = { + "type": "json_object", + "schema": fmt.json_schema, + } + elif fmt.type == ResponseFormatType.grammar.value: + options["response_format"] = { + "type": "grammar", + "grammar": fmt.bnf, + } + else: + raise ValueError(f"Unknown response format {fmt.type}") + + return options + async def chat_completion( self, model: str, @@ -121,32 +160,35 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference): logprobs=logprobs, ) - client = Fireworks(api_key=self.config.api_key) if stream: - return self._stream_chat_completion(request, client) + return self._stream_chat_completion(request) else: - return await self._nonstream_chat_completion(request, client) + return await self._nonstream_chat_completion(request) async def _nonstream_chat_completion( - self, request: ChatCompletionRequest, client: Fireworks + self, request: ChatCompletionRequest ) -> ChatCompletionResponse: params = await self._get_params(request) if "messages" in params: - r = await client.chat.completions.acreate(**params) + r = await self._get_client().chat.completions.acreate(**params) else: - r = await client.completion.acreate(**params) + r = await self._get_client().completion.acreate(**params) return process_chat_completion_response(r, self.formatter) async def _stream_chat_completion( - self, request: ChatCompletionRequest, client: Fireworks + self, request: ChatCompletionRequest ) -> AsyncGenerator: params = await self._get_params(request) - if "messages" in params: - stream = client.chat.completions.acreate(**params) - else: - stream = client.completion.acreate(**params) + async def _to_async_generator(): + if "messages" in params: + stream = 
await self._get_client().chat.completions.acreate(**params) + else: + stream = self._get_client().completion.create(**params) + for chunk in stream: + yield chunk + stream = _to_async_generator() async for chunk in process_chat_completion_stream_response( stream, self.formatter ): @@ -167,41 +209,22 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference): input_dict["prompt"] = chat_completion_request_to_prompt( request, self.formatter ) - elif isinstance(request, CompletionRequest): + else: assert ( not media_present ), "Fireworks does not support media for Completion requests" input_dict["prompt"] = completion_request_to_prompt(request, self.formatter) - else: - raise ValueError(f"Unknown request type {type(request)}") # Fireworks always prepends with BOS if "prompt" in input_dict: if input_dict["prompt"].startswith("<|begin_of_text|>"): input_dict["prompt"] = input_dict["prompt"][len("<|begin_of_text|>") :] - options = get_sampling_options(request.sampling_params) - options.setdefault("max_tokens", 512) - - if fmt := request.response_format: - if fmt.type == ResponseFormatType.json_schema.value: - options["response_format"] = { - "type": "json_object", - "schema": fmt.json_schema, - } - elif fmt.type == ResponseFormatType.grammar.value: - options["response_format"] = { - "type": "grammar", - "grammar": fmt.bnf, - } - else: - raise ValueError(f"Unknown response format {fmt.type}") - return { "model": self.map_to_provider_model(request.model), **input_dict, "stream": request.stream, - **options, + **self._build_options(request.sampling_params, request.response_format), } async def embeddings( From 36e2538eb0eacb01cc591d0f9ecd019c26ff8f62 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Thu, 7 Nov 2024 11:31:53 -0800 Subject: [PATCH 2/6] fix together inference validator (#393) --- llama_stack/providers/registry/inference.py | 2 +- llama_stack/providers/remote/inference/together/__init__.py | 6 ++++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/llama_stack/providers/registry/inference.py b/llama_stack/providers/registry/inference.py index 8a3619118..18fe8274e 100644 --- a/llama_stack/providers/registry/inference.py +++ b/llama_stack/providers/registry/inference.py @@ -117,7 +117,7 @@ def available_providers() -> List[ProviderSpec]: ], module="llama_stack.providers.remote.inference.together", config_class="llama_stack.providers.remote.inference.together.TogetherImplConfig", - provider_data_validator="llama_stack.providers.remote.safety.together.TogetherProviderDataValidator", + provider_data_validator="llama_stack.providers.remote.inference.together.TogetherProviderDataValidator", ), ), remote_provider_spec( diff --git a/llama_stack/providers/remote/inference/together/__init__.py b/llama_stack/providers/remote/inference/together/__init__.py index 05ea91e58..2bbd9ed53 100644 --- a/llama_stack/providers/remote/inference/together/__init__.py +++ b/llama_stack/providers/remote/inference/together/__init__.py @@ -4,9 +4,15 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
+from pydantic import BaseModel + from .config import TogetherImplConfig +class TogetherProviderDataValidator(BaseModel): + together_api_key: str + + async def get_adapter_impl(config: TogetherImplConfig, _deps): from .together import TogetherInferenceAdapter From 694c142b89d71b2320454a9c795c461df8335ada Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Thu, 7 Nov 2024 13:04:53 -0800 Subject: [PATCH 3/6] Add provider deprecation support; change directory structure (#397) * Add provider deprecation support; change directory structure * fix a couple dangling imports * move the meta_reference safety dir also --- llama_stack/distribution/resolver.py | 8 ++ llama_stack/providers/datatypes.py | 4 + .../meta_reference}/__init__.py | 0 .../meta_reference}/agent_instance.py | 0 .../meta_reference}/agents.py | 0 .../meta_reference}/config.py | 3 +- .../meta_reference}/persistence.py | 3 +- .../meta_reference}/rag/__init__.py | 0 .../meta_reference}/rag/context_retriever.py | 3 +- .../meta_reference}/safety.py | 0 .../meta_reference}/tests/__init__.py | 0 .../meta_reference}/tests/code_execution.py | 0 .../meta_reference}/tests/test_chat_agent.py | 0 .../meta_reference}/tools/__init__.py | 0 .../meta_reference}/tools/base.py | 0 .../meta_reference}/tools/builtin.py | 0 .../tools/ipython_tool/__init__.py | 0 .../tools/ipython_tool/code_env_prefix.py | 0 .../tools/ipython_tool/code_execution.py | 0 .../ipython_tool/matplotlib_custom_backend.py | 0 .../tools/ipython_tool/utils.py | 0 .../meta_reference}/tools/safety.py | 3 +- .../meta_reference}/__init__.py | 0 .../meta_reference}/config.py | 3 +- .../meta_reference}/generation.py | 3 +- .../meta_reference}/inference.py | 0 .../meta_reference}/model_parallel.py | 0 .../meta_reference}/parallel_utils.py | 4 +- .../meta_reference}/quantization/__init__.py | 0 .../meta_reference}/quantization/fp8_impls.py | 0 .../quantization/fp8_txest_disabled.py | 0 .../quantization/hadamard_utils.py | 0 .../meta_reference}/quantization/loader.py | 9 +-- .../quantization/scripts/__init__.py | 0 .../quantization/scripts/build_conda.sh | 0 .../scripts/quantize_checkpoint.py | 0 .../scripts/run_quantize_checkpoint.sh | 0 .../inline/{ => inference}/vllm/__init__.py | 0 .../inline/{ => inference}/vllm/config.py | 2 +- .../inline/{ => inference}/vllm/vllm.py | 0 .../memory => memory/faiss}/__init__.py | 0 .../memory => memory/faiss}/config.py | 2 +- .../memory => memory/faiss}/faiss.py | 3 +- .../meta_reference/memory/tests/test_faiss.py | 73 ------------------- .../meta_reference}/__init__.py | 0 .../safety => safety/meta_reference}/base.py | 0 .../meta_reference}/config.py | 0 .../meta_reference}/llama_guard.py | 0 .../meta_reference}/prompt_guard.py | 0 .../meta_reference}/safety.py | 0 llama_stack/providers/registry/agents.py | 4 +- llama_stack/providers/registry/inference.py | 26 +++---- llama_stack/providers/registry/memory.py | 12 ++- llama_stack/providers/registry/safety.py | 8 +- .../providers/tests/agents/fixtures.py | 2 +- .../providers/tests/inference/fixtures.py | 2 +- .../providers/tests/memory/fixtures.py | 2 +- .../providers/tests/safety/fixtures.py | 2 +- 58 files changed, 61 insertions(+), 120 deletions(-) rename llama_stack/providers/inline/{meta_reference/agents => agents/meta_reference}/__init__.py (100%) rename llama_stack/providers/inline/{meta_reference/agents => agents/meta_reference}/agent_instance.py (100%) rename llama_stack/providers/inline/{meta_reference/agents => agents/meta_reference}/agents.py (100%) rename 
llama_stack/providers/inline/{meta_reference/agents => agents/meta_reference}/config.py (99%) rename llama_stack/providers/inline/{meta_reference/agents => agents/meta_reference}/persistence.py (99%) rename llama_stack/providers/inline/{meta_reference/agents => agents/meta_reference}/rag/__init__.py (100%) rename llama_stack/providers/inline/{meta_reference/agents => agents/meta_reference}/rag/context_retriever.py (99%) rename llama_stack/providers/inline/{meta_reference/agents => agents/meta_reference}/safety.py (100%) rename llama_stack/providers/inline/{meta_reference/agents => agents/meta_reference}/tests/__init__.py (100%) rename llama_stack/providers/inline/{meta_reference/agents => agents/meta_reference}/tests/code_execution.py (100%) rename llama_stack/providers/inline/{meta_reference/agents => agents/meta_reference}/tests/test_chat_agent.py (100%) rename llama_stack/providers/inline/{meta_reference/agents => agents/meta_reference}/tools/__init__.py (100%) rename llama_stack/providers/inline/{meta_reference/agents => agents/meta_reference}/tools/base.py (100%) rename llama_stack/providers/inline/{meta_reference/agents => agents/meta_reference}/tools/builtin.py (100%) rename llama_stack/providers/inline/{meta_reference/agents => agents/meta_reference}/tools/ipython_tool/__init__.py (100%) rename llama_stack/providers/inline/{meta_reference/agents => agents/meta_reference}/tools/ipython_tool/code_env_prefix.py (100%) rename llama_stack/providers/inline/{meta_reference/agents => agents/meta_reference}/tools/ipython_tool/code_execution.py (100%) rename llama_stack/providers/inline/{meta_reference/agents => agents/meta_reference}/tools/ipython_tool/matplotlib_custom_backend.py (100%) rename llama_stack/providers/inline/{meta_reference/agents => agents/meta_reference}/tools/ipython_tool/utils.py (100%) rename llama_stack/providers/inline/{meta_reference/agents => agents/meta_reference}/tools/safety.py (93%) rename llama_stack/providers/inline/{meta_reference/inference => inference/meta_reference}/__init__.py (100%) rename llama_stack/providers/inline/{meta_reference/inference => inference/meta_reference}/config.py (99%) rename llama_stack/providers/inline/{meta_reference/inference => inference/meta_reference}/generation.py (99%) rename llama_stack/providers/inline/{meta_reference/inference => inference/meta_reference}/inference.py (100%) rename llama_stack/providers/inline/{meta_reference/inference => inference/meta_reference}/model_parallel.py (100%) rename llama_stack/providers/inline/{meta_reference/inference => inference/meta_reference}/parallel_utils.py (100%) rename llama_stack/providers/inline/{meta_reference/inference => inference/meta_reference}/quantization/__init__.py (100%) rename llama_stack/providers/inline/{meta_reference/inference => inference/meta_reference}/quantization/fp8_impls.py (100%) rename llama_stack/providers/inline/{meta_reference/inference => inference/meta_reference}/quantization/fp8_txest_disabled.py (100%) rename llama_stack/providers/inline/{meta_reference/inference => inference/meta_reference}/quantization/hadamard_utils.py (100%) rename llama_stack/providers/inline/{meta_reference/inference => inference/meta_reference}/quantization/loader.py (99%) rename llama_stack/providers/inline/{meta_reference/inference => inference/meta_reference}/quantization/scripts/__init__.py (100%) rename llama_stack/providers/inline/{meta_reference/inference => inference/meta_reference}/quantization/scripts/build_conda.sh (100%) rename 
llama_stack/providers/inline/{meta_reference/inference => inference/meta_reference}/quantization/scripts/quantize_checkpoint.py (100%) rename llama_stack/providers/inline/{meta_reference/inference => inference/meta_reference}/quantization/scripts/run_quantize_checkpoint.sh (100%) rename llama_stack/providers/inline/{ => inference}/vllm/__init__.py (100%) rename llama_stack/providers/inline/{ => inference}/vllm/config.py (100%) rename llama_stack/providers/inline/{ => inference}/vllm/vllm.py (100%) rename llama_stack/providers/inline/{meta_reference/memory => memory/faiss}/__init__.py (100%) rename llama_stack/providers/inline/{meta_reference/memory => memory/faiss}/config.py (100%) rename llama_stack/providers/inline/{meta_reference/memory => memory/faiss}/faiss.py (99%) delete mode 100644 llama_stack/providers/inline/meta_reference/memory/tests/test_faiss.py rename llama_stack/providers/inline/{meta_reference/safety => safety/meta_reference}/__init__.py (100%) rename llama_stack/providers/inline/{meta_reference/safety => safety/meta_reference}/base.py (100%) rename llama_stack/providers/inline/{meta_reference/safety => safety/meta_reference}/config.py (100%) rename llama_stack/providers/inline/{meta_reference/safety => safety/meta_reference}/llama_guard.py (100%) rename llama_stack/providers/inline/{meta_reference/safety => safety/meta_reference}/prompt_guard.py (100%) rename llama_stack/providers/inline/{meta_reference/safety => safety/meta_reference}/safety.py (100%) diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py index 96b4b81e6..9b8e41561 100644 --- a/llama_stack/distribution/resolver.py +++ b/llama_stack/distribution/resolver.py @@ -8,6 +8,8 @@ import inspect from typing import Any, Dict, List, Set +from termcolor import cprint + from llama_stack.providers.datatypes import * # noqa: F403 from llama_stack.distribution.datatypes import * # noqa: F403 @@ -97,6 +99,12 @@ async def resolve_impls( ) p = provider_registry[api][provider.provider_type] + if p.deprecation_warning: + cprint( + f"Provider `{provider.provider_type}` for API `{api}` is deprecated and will be removed in a future release: {p.deprecation_warning}", + "red", + attrs=["bold"], + ) p.deps__ = [a.value for a in p.api_dependencies] spec = ProviderWithSpec( spec=p, diff --git a/llama_stack/providers/datatypes.py b/llama_stack/providers/datatypes.py index 919507d11..59c5a38fa 100644 --- a/llama_stack/providers/datatypes.py +++ b/llama_stack/providers/datatypes.py @@ -82,6 +82,10 @@ class ProviderSpec(BaseModel): default_factory=list, description="Higher-level API surfaces may depend on other providers to provide their functionality", ) + deprecation_warning: Optional[str] = Field( + default=None, + description="If this provider is deprecated, specify the warning message here", + ) # used internally by the resolver; this is a hack for now deps__: List[str] = Field(default_factory=list) diff --git a/llama_stack/providers/inline/meta_reference/agents/__init__.py b/llama_stack/providers/inline/agents/meta_reference/__init__.py similarity index 100% rename from llama_stack/providers/inline/meta_reference/agents/__init__.py rename to llama_stack/providers/inline/agents/meta_reference/__init__.py diff --git a/llama_stack/providers/inline/meta_reference/agents/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py similarity index 100% rename from llama_stack/providers/inline/meta_reference/agents/agent_instance.py rename to 
llama_stack/providers/inline/agents/meta_reference/agent_instance.py diff --git a/llama_stack/providers/inline/meta_reference/agents/agents.py b/llama_stack/providers/inline/agents/meta_reference/agents.py similarity index 100% rename from llama_stack/providers/inline/meta_reference/agents/agents.py rename to llama_stack/providers/inline/agents/meta_reference/agents.py diff --git a/llama_stack/providers/inline/meta_reference/agents/config.py b/llama_stack/providers/inline/agents/meta_reference/config.py similarity index 99% rename from llama_stack/providers/inline/meta_reference/agents/config.py rename to llama_stack/providers/inline/agents/meta_reference/config.py index 2770ed13c..8ade558c3 100644 --- a/llama_stack/providers/inline/meta_reference/agents/config.py +++ b/llama_stack/providers/inline/agents/meta_reference/config.py @@ -4,10 +4,9 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from pydantic import BaseModel, Field - from llama_stack.providers.utils.kvstore import KVStoreConfig from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig +from pydantic import BaseModel, Field class MetaReferenceAgentsImplConfig(BaseModel): diff --git a/llama_stack/providers/inline/meta_reference/agents/persistence.py b/llama_stack/providers/inline/agents/meta_reference/persistence.py similarity index 99% rename from llama_stack/providers/inline/meta_reference/agents/persistence.py rename to llama_stack/providers/inline/agents/meta_reference/persistence.py index 37ac75d6a..36ae9b367 100644 --- a/llama_stack/providers/inline/meta_reference/agents/persistence.py +++ b/llama_stack/providers/inline/agents/meta_reference/persistence.py @@ -11,9 +11,8 @@ from datetime import datetime from typing import List, Optional from llama_stack.apis.agents import * # noqa: F403 -from pydantic import BaseModel - from llama_stack.providers.utils.kvstore import KVStore +from pydantic import BaseModel class AgentSessionInfo(BaseModel): diff --git a/llama_stack/providers/inline/meta_reference/agents/rag/__init__.py b/llama_stack/providers/inline/agents/meta_reference/rag/__init__.py similarity index 100% rename from llama_stack/providers/inline/meta_reference/agents/rag/__init__.py rename to llama_stack/providers/inline/agents/meta_reference/rag/__init__.py diff --git a/llama_stack/providers/inline/meta_reference/agents/rag/context_retriever.py b/llama_stack/providers/inline/agents/meta_reference/rag/context_retriever.py similarity index 99% rename from llama_stack/providers/inline/meta_reference/agents/rag/context_retriever.py rename to llama_stack/providers/inline/agents/meta_reference/rag/context_retriever.py index b668dc0d6..3b303f5bd 100644 --- a/llama_stack/providers/inline/meta_reference/agents/rag/context_retriever.py +++ b/llama_stack/providers/inline/agents/meta_reference/rag/context_retriever.py @@ -10,14 +10,13 @@ from jinja2 import Template from llama_models.llama3.api import * # noqa: F403 -from termcolor import cprint # noqa: F401 - from llama_stack.apis.agents import ( DefaultMemoryQueryGeneratorConfig, LLMMemoryQueryGeneratorConfig, MemoryQueryGenerator, MemoryQueryGeneratorConfig, ) +from termcolor import cprint # noqa: F401 from llama_stack.apis.inference import * # noqa: F403 diff --git a/llama_stack/providers/inline/meta_reference/agents/safety.py b/llama_stack/providers/inline/agents/meta_reference/safety.py similarity index 100% rename from llama_stack/providers/inline/meta_reference/agents/safety.py 
rename to llama_stack/providers/inline/agents/meta_reference/safety.py diff --git a/llama_stack/providers/inline/meta_reference/agents/tests/__init__.py b/llama_stack/providers/inline/agents/meta_reference/tests/__init__.py similarity index 100% rename from llama_stack/providers/inline/meta_reference/agents/tests/__init__.py rename to llama_stack/providers/inline/agents/meta_reference/tests/__init__.py diff --git a/llama_stack/providers/inline/meta_reference/agents/tests/code_execution.py b/llama_stack/providers/inline/agents/meta_reference/tests/code_execution.py similarity index 100% rename from llama_stack/providers/inline/meta_reference/agents/tests/code_execution.py rename to llama_stack/providers/inline/agents/meta_reference/tests/code_execution.py diff --git a/llama_stack/providers/inline/meta_reference/agents/tests/test_chat_agent.py b/llama_stack/providers/inline/agents/meta_reference/tests/test_chat_agent.py similarity index 100% rename from llama_stack/providers/inline/meta_reference/agents/tests/test_chat_agent.py rename to llama_stack/providers/inline/agents/meta_reference/tests/test_chat_agent.py diff --git a/llama_stack/providers/inline/meta_reference/agents/tools/__init__.py b/llama_stack/providers/inline/agents/meta_reference/tools/__init__.py similarity index 100% rename from llama_stack/providers/inline/meta_reference/agents/tools/__init__.py rename to llama_stack/providers/inline/agents/meta_reference/tools/__init__.py diff --git a/llama_stack/providers/inline/meta_reference/agents/tools/base.py b/llama_stack/providers/inline/agents/meta_reference/tools/base.py similarity index 100% rename from llama_stack/providers/inline/meta_reference/agents/tools/base.py rename to llama_stack/providers/inline/agents/meta_reference/tools/base.py diff --git a/llama_stack/providers/inline/meta_reference/agents/tools/builtin.py b/llama_stack/providers/inline/agents/meta_reference/tools/builtin.py similarity index 100% rename from llama_stack/providers/inline/meta_reference/agents/tools/builtin.py rename to llama_stack/providers/inline/agents/meta_reference/tools/builtin.py diff --git a/llama_stack/providers/inline/meta_reference/agents/tools/ipython_tool/__init__.py b/llama_stack/providers/inline/agents/meta_reference/tools/ipython_tool/__init__.py similarity index 100% rename from llama_stack/providers/inline/meta_reference/agents/tools/ipython_tool/__init__.py rename to llama_stack/providers/inline/agents/meta_reference/tools/ipython_tool/__init__.py diff --git a/llama_stack/providers/inline/meta_reference/agents/tools/ipython_tool/code_env_prefix.py b/llama_stack/providers/inline/agents/meta_reference/tools/ipython_tool/code_env_prefix.py similarity index 100% rename from llama_stack/providers/inline/meta_reference/agents/tools/ipython_tool/code_env_prefix.py rename to llama_stack/providers/inline/agents/meta_reference/tools/ipython_tool/code_env_prefix.py diff --git a/llama_stack/providers/inline/meta_reference/agents/tools/ipython_tool/code_execution.py b/llama_stack/providers/inline/agents/meta_reference/tools/ipython_tool/code_execution.py similarity index 100% rename from llama_stack/providers/inline/meta_reference/agents/tools/ipython_tool/code_execution.py rename to llama_stack/providers/inline/agents/meta_reference/tools/ipython_tool/code_execution.py diff --git a/llama_stack/providers/inline/meta_reference/agents/tools/ipython_tool/matplotlib_custom_backend.py b/llama_stack/providers/inline/agents/meta_reference/tools/ipython_tool/matplotlib_custom_backend.py similarity 
index 100% rename from llama_stack/providers/inline/meta_reference/agents/tools/ipython_tool/matplotlib_custom_backend.py rename to llama_stack/providers/inline/agents/meta_reference/tools/ipython_tool/matplotlib_custom_backend.py diff --git a/llama_stack/providers/inline/meta_reference/agents/tools/ipython_tool/utils.py b/llama_stack/providers/inline/agents/meta_reference/tools/ipython_tool/utils.py similarity index 100% rename from llama_stack/providers/inline/meta_reference/agents/tools/ipython_tool/utils.py rename to llama_stack/providers/inline/agents/meta_reference/tools/ipython_tool/utils.py diff --git a/llama_stack/providers/inline/meta_reference/agents/tools/safety.py b/llama_stack/providers/inline/agents/meta_reference/tools/safety.py similarity index 93% rename from llama_stack/providers/inline/meta_reference/agents/tools/safety.py rename to llama_stack/providers/inline/agents/meta_reference/tools/safety.py index 72530f0e6..1ffc99edd 100644 --- a/llama_stack/providers/inline/meta_reference/agents/tools/safety.py +++ b/llama_stack/providers/inline/agents/meta_reference/tools/safety.py @@ -9,8 +9,7 @@ from typing import List from llama_stack.apis.inference import Message from llama_stack.apis.safety import * # noqa: F403 -from llama_stack.providers.inline.meta_reference.agents.safety import ShieldRunnerMixin - +from ..safety import ShieldRunnerMixin from .builtin import BaseTool diff --git a/llama_stack/providers/inline/meta_reference/inference/__init__.py b/llama_stack/providers/inline/inference/meta_reference/__init__.py similarity index 100% rename from llama_stack/providers/inline/meta_reference/inference/__init__.py rename to llama_stack/providers/inline/inference/meta_reference/__init__.py diff --git a/llama_stack/providers/inline/meta_reference/inference/config.py b/llama_stack/providers/inline/inference/meta_reference/config.py similarity index 99% rename from llama_stack/providers/inline/meta_reference/inference/config.py rename to llama_stack/providers/inline/inference/meta_reference/config.py index 48cba645b..6ecba22b0 100644 --- a/llama_stack/providers/inline/meta_reference/inference/config.py +++ b/llama_stack/providers/inline/inference/meta_reference/config.py @@ -10,9 +10,8 @@ from llama_models.datatypes import * # noqa: F403 from llama_models.sku_list import resolve_model from llama_stack.apis.inference import * # noqa: F401, F403 -from pydantic import BaseModel, Field, field_validator - from llama_stack.providers.utils.inference import supported_inference_models +from pydantic import BaseModel, Field, field_validator class MetaReferenceInferenceConfig(BaseModel): diff --git a/llama_stack/providers/inline/meta_reference/inference/generation.py b/llama_stack/providers/inline/inference/meta_reference/generation.py similarity index 99% rename from llama_stack/providers/inline/meta_reference/inference/generation.py rename to llama_stack/providers/inline/inference/meta_reference/generation.py index 2f296c7c2..8d6a14fc9 100644 --- a/llama_stack/providers/inline/meta_reference/inference/generation.py +++ b/llama_stack/providers/inline/inference/meta_reference/generation.py @@ -35,13 +35,12 @@ from termcolor import cprint from llama_stack.apis.inference import * # noqa: F403 -from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerTokenizerData - from llama_stack.distribution.utils.model_utils import model_local_dir from llama_stack.providers.utils.inference.prompt_adapter import ( augment_content_with_response_format_prompt, 
chat_completion_request_to_messages, ) +from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerTokenizerData from .config import ( Fp8QuantizationConfig, diff --git a/llama_stack/providers/inline/meta_reference/inference/inference.py b/llama_stack/providers/inline/inference/meta_reference/inference.py similarity index 100% rename from llama_stack/providers/inline/meta_reference/inference/inference.py rename to llama_stack/providers/inline/inference/meta_reference/inference.py diff --git a/llama_stack/providers/inline/meta_reference/inference/model_parallel.py b/llama_stack/providers/inline/inference/meta_reference/model_parallel.py similarity index 100% rename from llama_stack/providers/inline/meta_reference/inference/model_parallel.py rename to llama_stack/providers/inline/inference/meta_reference/model_parallel.py diff --git a/llama_stack/providers/inline/meta_reference/inference/parallel_utils.py b/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py similarity index 100% rename from llama_stack/providers/inline/meta_reference/inference/parallel_utils.py rename to llama_stack/providers/inline/inference/meta_reference/parallel_utils.py index 62eeefaac..470b6b1ca 100644 --- a/llama_stack/providers/inline/meta_reference/inference/parallel_utils.py +++ b/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py @@ -28,13 +28,13 @@ from fairscale.nn.model_parallel.initialize import ( get_model_parallel_src_rank, ) +from llama_stack.apis.inference import ChatCompletionRequest, CompletionRequest + from pydantic import BaseModel, Field from torch.distributed.launcher.api import elastic_launch, LaunchConfig from typing_extensions import Annotated -from llama_stack.apis.inference import ChatCompletionRequest, CompletionRequest - from .generation import TokenResult diff --git a/llama_stack/providers/inline/meta_reference/inference/quantization/__init__.py b/llama_stack/providers/inline/inference/meta_reference/quantization/__init__.py similarity index 100% rename from llama_stack/providers/inline/meta_reference/inference/quantization/__init__.py rename to llama_stack/providers/inline/inference/meta_reference/quantization/__init__.py diff --git a/llama_stack/providers/inline/meta_reference/inference/quantization/fp8_impls.py b/llama_stack/providers/inline/inference/meta_reference/quantization/fp8_impls.py similarity index 100% rename from llama_stack/providers/inline/meta_reference/inference/quantization/fp8_impls.py rename to llama_stack/providers/inline/inference/meta_reference/quantization/fp8_impls.py diff --git a/llama_stack/providers/inline/meta_reference/inference/quantization/fp8_txest_disabled.py b/llama_stack/providers/inline/inference/meta_reference/quantization/fp8_txest_disabled.py similarity index 100% rename from llama_stack/providers/inline/meta_reference/inference/quantization/fp8_txest_disabled.py rename to llama_stack/providers/inline/inference/meta_reference/quantization/fp8_txest_disabled.py diff --git a/llama_stack/providers/inline/meta_reference/inference/quantization/hadamard_utils.py b/llama_stack/providers/inline/inference/meta_reference/quantization/hadamard_utils.py similarity index 100% rename from llama_stack/providers/inline/meta_reference/inference/quantization/hadamard_utils.py rename to llama_stack/providers/inline/inference/meta_reference/quantization/hadamard_utils.py diff --git a/llama_stack/providers/inline/meta_reference/inference/quantization/loader.py 
b/llama_stack/providers/inline/inference/meta_reference/quantization/loader.py similarity index 99% rename from llama_stack/providers/inline/meta_reference/inference/quantization/loader.py rename to llama_stack/providers/inline/inference/meta_reference/quantization/loader.py index 3492ab043..286224931 100644 --- a/llama_stack/providers/inline/meta_reference/inference/quantization/loader.py +++ b/llama_stack/providers/inline/inference/meta_reference/quantization/loader.py @@ -20,16 +20,15 @@ from llama_models.datatypes import CheckpointQuantizationFormat from llama_models.llama3.api.args import ModelArgs from llama_models.llama3.reference_impl.model import Transformer, TransformerBlock from llama_models.sku_list import resolve_model + +from llama_stack.apis.inference import QuantizationType + from termcolor import cprint from torch import nn, Tensor from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear -from llama_stack.apis.inference import QuantizationType - -from llama_stack.providers.inline.meta_reference.inference.config import ( - MetaReferenceQuantizedInferenceConfig, -) +from ..config import MetaReferenceQuantizedInferenceConfig def swiglu_wrapper( diff --git a/llama_stack/providers/inline/meta_reference/inference/quantization/scripts/__init__.py b/llama_stack/providers/inline/inference/meta_reference/quantization/scripts/__init__.py similarity index 100% rename from llama_stack/providers/inline/meta_reference/inference/quantization/scripts/__init__.py rename to llama_stack/providers/inline/inference/meta_reference/quantization/scripts/__init__.py diff --git a/llama_stack/providers/inline/meta_reference/inference/quantization/scripts/build_conda.sh b/llama_stack/providers/inline/inference/meta_reference/quantization/scripts/build_conda.sh similarity index 100% rename from llama_stack/providers/inline/meta_reference/inference/quantization/scripts/build_conda.sh rename to llama_stack/providers/inline/inference/meta_reference/quantization/scripts/build_conda.sh diff --git a/llama_stack/providers/inline/meta_reference/inference/quantization/scripts/quantize_checkpoint.py b/llama_stack/providers/inline/inference/meta_reference/quantization/scripts/quantize_checkpoint.py similarity index 100% rename from llama_stack/providers/inline/meta_reference/inference/quantization/scripts/quantize_checkpoint.py rename to llama_stack/providers/inline/inference/meta_reference/quantization/scripts/quantize_checkpoint.py diff --git a/llama_stack/providers/inline/meta_reference/inference/quantization/scripts/run_quantize_checkpoint.sh b/llama_stack/providers/inline/inference/meta_reference/quantization/scripts/run_quantize_checkpoint.sh similarity index 100% rename from llama_stack/providers/inline/meta_reference/inference/quantization/scripts/run_quantize_checkpoint.sh rename to llama_stack/providers/inline/inference/meta_reference/quantization/scripts/run_quantize_checkpoint.sh diff --git a/llama_stack/providers/inline/vllm/__init__.py b/llama_stack/providers/inline/inference/vllm/__init__.py similarity index 100% rename from llama_stack/providers/inline/vllm/__init__.py rename to llama_stack/providers/inline/inference/vllm/__init__.py diff --git a/llama_stack/providers/inline/vllm/config.py b/llama_stack/providers/inline/inference/vllm/config.py similarity index 100% rename from llama_stack/providers/inline/vllm/config.py rename to llama_stack/providers/inline/inference/vllm/config.py index a7469ebde..22b439f77 100644 --- a/llama_stack/providers/inline/vllm/config.py +++ 
b/llama_stack/providers/inline/inference/vllm/config.py @@ -5,9 +5,9 @@ # the root directory of this source tree. from llama_models.schema_utils import json_schema_type -from pydantic import BaseModel, Field, field_validator from llama_stack.providers.utils.inference import supported_inference_models +from pydantic import BaseModel, Field, field_validator @json_schema_type diff --git a/llama_stack/providers/inline/vllm/vllm.py b/llama_stack/providers/inline/inference/vllm/vllm.py similarity index 100% rename from llama_stack/providers/inline/vllm/vllm.py rename to llama_stack/providers/inline/inference/vllm/vllm.py diff --git a/llama_stack/providers/inline/meta_reference/memory/__init__.py b/llama_stack/providers/inline/memory/faiss/__init__.py similarity index 100% rename from llama_stack/providers/inline/meta_reference/memory/__init__.py rename to llama_stack/providers/inline/memory/faiss/__init__.py diff --git a/llama_stack/providers/inline/meta_reference/memory/config.py b/llama_stack/providers/inline/memory/faiss/config.py similarity index 100% rename from llama_stack/providers/inline/meta_reference/memory/config.py rename to llama_stack/providers/inline/memory/faiss/config.py index 41970b05f..fd26272ae 100644 --- a/llama_stack/providers/inline/meta_reference/memory/config.py +++ b/llama_stack/providers/inline/memory/faiss/config.py @@ -5,13 +5,13 @@ # the root directory of this source tree. from llama_models.schema_utils import json_schema_type -from pydantic import BaseModel from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR from llama_stack.providers.utils.kvstore.config import ( KVStoreConfig, SqliteKVStoreConfig, ) +from pydantic import BaseModel @json_schema_type diff --git a/llama_stack/providers/inline/meta_reference/memory/faiss.py b/llama_stack/providers/inline/memory/faiss/faiss.py similarity index 99% rename from llama_stack/providers/inline/meta_reference/memory/faiss.py rename to llama_stack/providers/inline/memory/faiss/faiss.py index 4bd5fd5a7..5726d6f87 100644 --- a/llama_stack/providers/inline/meta_reference/memory/faiss.py +++ b/llama_stack/providers/inline/memory/faiss/faiss.py @@ -8,10 +8,11 @@ import logging from typing import Any, Dict, List, Optional -import faiss import numpy as np from numpy.typing import NDArray +import faiss + from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_stack.apis.memory import * # noqa: F403 diff --git a/llama_stack/providers/inline/meta_reference/memory/tests/test_faiss.py b/llama_stack/providers/inline/meta_reference/memory/tests/test_faiss.py deleted file mode 100644 index 7b944319f..000000000 --- a/llama_stack/providers/inline/meta_reference/memory/tests/test_faiss.py +++ /dev/null @@ -1,73 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. 
- -import tempfile - -import pytest -from llama_stack.apis.memory import MemoryBankType, VectorMemoryBankDef -from llama_stack.providers.inline.meta_reference.memory.config import FaissImplConfig - -from llama_stack.providers.inline.meta_reference.memory.faiss import FaissMemoryImpl -from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig - - -class TestFaissMemoryImpl: - @pytest.fixture - def faiss_impl(self): - # Create a temporary SQLite database file - temp_db = tempfile.NamedTemporaryFile(suffix=".db", delete=False) - config = FaissImplConfig(kvstore=SqliteKVStoreConfig(db_path=temp_db.name)) - return FaissMemoryImpl(config) - - @pytest.mark.asyncio - async def test_initialize(self, faiss_impl): - # Test empty initialization - await faiss_impl.initialize() - assert len(faiss_impl.cache) == 0 - - # Test initialization with existing banks - bank = VectorMemoryBankDef( - identifier="test_bank", - type=MemoryBankType.vector.value, - embedding_model="all-MiniLM-L6-v2", - chunk_size_in_tokens=512, - overlap_size_in_tokens=64, - ) - - # Register a bank and reinitialize to test loading - await faiss_impl.register_memory_bank(bank) - - # Create new instance to test initialization with existing data - new_impl = FaissMemoryImpl(faiss_impl.config) - await new_impl.initialize() - - assert len(new_impl.cache) == 1 - assert "test_bank" in new_impl.cache - - @pytest.mark.asyncio - async def test_register_memory_bank(self, faiss_impl): - bank = VectorMemoryBankDef( - identifier="test_bank", - type=MemoryBankType.vector.value, - embedding_model="all-MiniLM-L6-v2", - chunk_size_in_tokens=512, - overlap_size_in_tokens=64, - ) - - await faiss_impl.initialize() - await faiss_impl.register_memory_bank(bank) - - assert "test_bank" in faiss_impl.cache - assert faiss_impl.cache["test_bank"].bank == bank - - # Verify persistence - new_impl = FaissMemoryImpl(faiss_impl.config) - await new_impl.initialize() - assert "test_bank" in new_impl.cache - - -if __name__ == "__main__": - pytest.main([__file__]) diff --git a/llama_stack/providers/inline/meta_reference/safety/__init__.py b/llama_stack/providers/inline/safety/meta_reference/__init__.py similarity index 100% rename from llama_stack/providers/inline/meta_reference/safety/__init__.py rename to llama_stack/providers/inline/safety/meta_reference/__init__.py diff --git a/llama_stack/providers/inline/meta_reference/safety/base.py b/llama_stack/providers/inline/safety/meta_reference/base.py similarity index 100% rename from llama_stack/providers/inline/meta_reference/safety/base.py rename to llama_stack/providers/inline/safety/meta_reference/base.py diff --git a/llama_stack/providers/inline/meta_reference/safety/config.py b/llama_stack/providers/inline/safety/meta_reference/config.py similarity index 100% rename from llama_stack/providers/inline/meta_reference/safety/config.py rename to llama_stack/providers/inline/safety/meta_reference/config.py diff --git a/llama_stack/providers/inline/meta_reference/safety/llama_guard.py b/llama_stack/providers/inline/safety/meta_reference/llama_guard.py similarity index 100% rename from llama_stack/providers/inline/meta_reference/safety/llama_guard.py rename to llama_stack/providers/inline/safety/meta_reference/llama_guard.py diff --git a/llama_stack/providers/inline/meta_reference/safety/prompt_guard.py b/llama_stack/providers/inline/safety/meta_reference/prompt_guard.py similarity index 100% rename from llama_stack/providers/inline/meta_reference/safety/prompt_guard.py rename to 
llama_stack/providers/inline/safety/meta_reference/prompt_guard.py diff --git a/llama_stack/providers/inline/meta_reference/safety/safety.py b/llama_stack/providers/inline/safety/meta_reference/safety.py similarity index 100% rename from llama_stack/providers/inline/meta_reference/safety/safety.py rename to llama_stack/providers/inline/safety/meta_reference/safety.py diff --git a/llama_stack/providers/registry/agents.py b/llama_stack/providers/registry/agents.py index 774dde858..989b9f077 100644 --- a/llama_stack/providers/registry/agents.py +++ b/llama_stack/providers/registry/agents.py @@ -22,8 +22,8 @@ def available_providers() -> List[ProviderSpec]: "scikit-learn", ] + kvstore_dependencies(), - module="llama_stack.providers.inline.meta_reference.agents", - config_class="llama_stack.providers.inline.meta_reference.agents.MetaReferenceAgentsImplConfig", + module="llama_stack.providers.inline.agents.meta_reference", + config_class="llama_stack.providers.inline.agents.meta_reference.MetaReferenceAgentsImplConfig", api_dependencies=[ Api.inference, Api.safety, diff --git a/llama_stack/providers/registry/inference.py b/llama_stack/providers/registry/inference.py index 18fe8274e..dc6fa9592 100644 --- a/llama_stack/providers/registry/inference.py +++ b/llama_stack/providers/registry/inference.py @@ -27,8 +27,8 @@ def available_providers() -> List[ProviderSpec]: api=Api.inference, provider_type="meta-reference", pip_packages=META_REFERENCE_DEPS, - module="llama_stack.providers.inline.meta_reference.inference", - config_class="llama_stack.providers.inline.meta_reference.inference.MetaReferenceInferenceConfig", + module="llama_stack.providers.inline.inference.meta_reference", + config_class="llama_stack.providers.inline.inference.meta_reference.MetaReferenceInferenceConfig", ), InlineProviderSpec( api=Api.inference, @@ -40,8 +40,17 @@ def available_providers() -> List[ProviderSpec]: "torchao==0.5.0", ] ), - module="llama_stack.providers.inline.meta_reference.inference", - config_class="llama_stack.providers.inline.meta_reference.inference.MetaReferenceQuantizedInferenceConfig", + module="llama_stack.providers.inline.inference.meta_reference", + config_class="llama_stack.providers.inline.inference.meta_reference.MetaReferenceQuantizedInferenceConfig", + ), + InlineProviderSpec( + api=Api.inference, + provider_type="vllm", + pip_packages=[ + "vllm", + ], + module="llama_stack.providers.inline.inference.vllm", + config_class="llama_stack.providers.inline.inference.vllm.VLLMConfig", ), remote_provider_spec( api=Api.inference, @@ -140,13 +149,4 @@ def available_providers() -> List[ProviderSpec]: config_class="llama_stack.providers.remote.inference.databricks.DatabricksImplConfig", ), ), - InlineProviderSpec( - api=Api.inference, - provider_type="vllm", - pip_packages=[ - "vllm", - ], - module="llama_stack.providers.inline.vllm", - config_class="llama_stack.providers.inline.vllm.VLLMConfig", - ), ] diff --git a/llama_stack/providers/registry/memory.py b/llama_stack/providers/registry/memory.py index c2740017a..93ecb7c13 100644 --- a/llama_stack/providers/registry/memory.py +++ b/llama_stack/providers/registry/memory.py @@ -36,8 +36,16 @@ def available_providers() -> List[ProviderSpec]: api=Api.memory, provider_type="meta-reference", pip_packages=EMBEDDING_DEPS + ["faiss-cpu"], - module="llama_stack.providers.inline.meta_reference.memory", - config_class="llama_stack.providers.inline.meta_reference.memory.FaissImplConfig", + module="llama_stack.providers.inline.memory.faiss", + 
config_class="llama_stack.providers.inline.memory.faiss.FaissImplConfig", + deprecation_warning="Please use the `faiss` provider instead.", + ), + InlineProviderSpec( + api=Api.memory, + provider_type="faiss", + pip_packages=EMBEDDING_DEPS + ["faiss-cpu"], + module="llama_stack.providers.inline.memory.faiss", + config_class="llama_stack.providers.inline.memory.faiss.FaissImplConfig", ), remote_provider_spec( Api.memory, diff --git a/llama_stack/providers/registry/safety.py b/llama_stack/providers/registry/safety.py index fdaa33192..fb5b6695a 100644 --- a/llama_stack/providers/registry/safety.py +++ b/llama_stack/providers/registry/safety.py @@ -24,8 +24,8 @@ def available_providers() -> List[ProviderSpec]: "transformers", "torch --index-url https://download.pytorch.org/whl/cpu", ], - module="llama_stack.providers.inline.meta_reference.safety", - config_class="llama_stack.providers.inline.meta_reference.safety.SafetyConfig", + module="llama_stack.providers.inline.safety.meta_reference", + config_class="llama_stack.providers.inline.safety.meta_reference.SafetyConfig", api_dependencies=[ Api.inference, ], @@ -54,8 +54,8 @@ def available_providers() -> List[ProviderSpec]: pip_packages=[ "codeshield", ], - module="llama_stack.providers.inline.meta_reference.codeshield", - config_class="llama_stack.providers.inline.meta_reference.codeshield.CodeShieldConfig", + module="llama_stack.providers.inline.safety.meta_reference", + config_class="llama_stack.providers.inline.safety.meta_reference.CodeShieldConfig", api_dependencies=[], ), ] diff --git a/llama_stack/providers/tests/agents/fixtures.py b/llama_stack/providers/tests/agents/fixtures.py index 86ecae1e9..8330e2604 100644 --- a/llama_stack/providers/tests/agents/fixtures.py +++ b/llama_stack/providers/tests/agents/fixtures.py @@ -11,7 +11,7 @@ import pytest_asyncio from llama_stack.distribution.datatypes import Api, Provider -from llama_stack.providers.inline.meta_reference.agents import ( +from llama_stack.providers.inline.agents.meta_reference import ( MetaReferenceAgentsImplConfig, ) diff --git a/llama_stack/providers/tests/inference/fixtures.py b/llama_stack/providers/tests/inference/fixtures.py index 9db70888e..5b047549b 100644 --- a/llama_stack/providers/tests/inference/fixtures.py +++ b/llama_stack/providers/tests/inference/fixtures.py @@ -10,7 +10,7 @@ import pytest import pytest_asyncio from llama_stack.distribution.datatypes import Api, Provider -from llama_stack.providers.inline.meta_reference.inference import ( +from llama_stack.providers.inline.inference.meta_reference import ( MetaReferenceInferenceConfig, ) diff --git a/llama_stack/providers/tests/memory/fixtures.py b/llama_stack/providers/tests/memory/fixtures.py index b30e0fae4..c0931b009 100644 --- a/llama_stack/providers/tests/memory/fixtures.py +++ b/llama_stack/providers/tests/memory/fixtures.py @@ -11,7 +11,7 @@ import pytest import pytest_asyncio from llama_stack.distribution.datatypes import Api, Provider -from llama_stack.providers.inline.meta_reference.memory import FaissImplConfig +from llama_stack.providers.inline.memory.faiss import FaissImplConfig from llama_stack.providers.remote.memory.pgvector import PGVectorConfig from llama_stack.providers.remote.memory.weaviate import WeaviateConfig diff --git a/llama_stack/providers/tests/safety/fixtures.py b/llama_stack/providers/tests/safety/fixtures.py index de1829355..58859c991 100644 --- a/llama_stack/providers/tests/safety/fixtures.py +++ b/llama_stack/providers/tests/safety/fixtures.py @@ -8,7 +8,7 @@ import pytest 
import pytest_asyncio from llama_stack.distribution.datatypes import Api, Provider -from llama_stack.providers.inline.meta_reference.safety import ( +from llama_stack.providers.inline.safety.meta_reference import ( LlamaGuardShieldConfig, SafetyConfig, ) From 345ae07317e96ec8554a404cc4dd60b22a418467 Mon Sep 17 00:00:00 2001 From: Dalton Flanagan <6599399+dltn@users.noreply.github.com> Date: Thu, 7 Nov 2024 16:13:19 -0500 Subject: [PATCH 4/6] Factor out create_dist_registry (#398) --- llama_stack/distribution/server/server.py | 19 ++------------ llama_stack/distribution/store/registry.py | 30 ++++++++++++++++++++-- 2 files changed, 30 insertions(+), 19 deletions(-) diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index 16c0fd0e0..143813780 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -31,7 +31,7 @@ from llama_stack.distribution.distribution import ( get_provider_registry, ) -from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR +from llama_stack.distribution.store.registry import create_dist_registry from llama_stack.providers.utils.telemetry.tracing import ( end_trace, @@ -42,8 +42,6 @@ from llama_stack.providers.utils.telemetry.tracing import ( from llama_stack.distribution.datatypes import * # noqa: F403 from llama_stack.distribution.request_headers import set_request_provider_data from llama_stack.distribution.resolver import resolve_impls -from llama_stack.distribution.store import CachedDiskDistributionRegistry -from llama_stack.providers.utils.kvstore import kvstore_impl, SqliteKVStoreConfig from .endpoints import get_all_api_endpoints @@ -281,21 +279,8 @@ def main( config = StackRunConfig(**yaml.safe_load(fp)) app = FastAPI() - # instantiate kvstore for storing and retrieving distribution metadata - if config.metadata_store: - dist_kvstore = asyncio.run(kvstore_impl(config.metadata_store)) - else: - dist_kvstore = asyncio.run( - kvstore_impl( - SqliteKVStoreConfig( - db_path=( - DISTRIBS_BASE_DIR / config.image_name / "kvstore.db" - ).as_posix() - ) - ) - ) - dist_registry = CachedDiskDistributionRegistry(dist_kvstore) + dist_registry, dist_kvstore = asyncio.run(create_dist_registry(config)) impls = asyncio.run(resolve_impls(config, get_provider_registry(), dist_registry)) if Api.telemetry in impls: diff --git a/llama_stack/distribution/store/registry.py b/llama_stack/distribution/store/registry.py index 994fb475c..897bb90d0 100644 --- a/llama_stack/distribution/store/registry.py +++ b/llama_stack/distribution/store/registry.py @@ -9,9 +9,17 @@ from typing import Dict, List, Protocol import pydantic -from llama_stack.distribution.datatypes import RoutableObjectWithProvider +from llama_stack.distribution.datatypes import ( + RoutableObjectWithProvider, + StackRunConfig, +) +from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR -from llama_stack.providers.utils.kvstore import KVStore +from llama_stack.providers.utils.kvstore import ( + KVStore, + kvstore_impl, + SqliteKVStoreConfig, +) class DistributionRegistry(Protocol): @@ -133,3 +141,21 @@ class CachedDiskDistributionRegistry(DiskDistributionRegistry): self.cache[obj.identifier].append(obj) return success + + +async def create_dist_registry( + config: StackRunConfig, +) -> tuple[CachedDiskDistributionRegistry, KVStore]: + # instantiate kvstore for storing and retrieving distribution metadata + if config.metadata_store: + dist_kvstore = await kvstore_impl(config.metadata_store) + else: 
+ dist_kvstore = await kvstore_impl( + SqliteKVStoreConfig( + db_path=( + DISTRIBS_BASE_DIR / config.image_name / "kvstore.db" + ).as_posix() + ) + ) + + return CachedDiskDistributionRegistry(dist_kvstore), dist_kvstore From 8350f2df4c530c16b53ffaa7a7ba0a677df74be8 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Thu, 7 Nov 2024 19:16:38 -0800 Subject: [PATCH 5/6] [docs] refactor remote-hosted distro (#402) * move docs * docs --- CONTRIBUTING.md | 1 + .../remote_hosted_distro/index.md | 45 +++++++++++++++---- .../bedrock.md | 0 .../fireworks.md | 0 .../distributions/self_hosted_distro/index.md | 7 +++ .../together.md | 0 6 files changed, 44 insertions(+), 9 deletions(-) rename docs/source/getting_started/distributions/{remote_hosted_distro => self_hosted_distro}/bedrock.md (100%) rename docs/source/getting_started/distributions/{remote_hosted_distro => self_hosted_distro}/fireworks.md (100%) rename docs/source/getting_started/distributions/{remote_hosted_distro => self_hosted_distro}/together.md (100%) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index ab9c4d82e..7e05c683a 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -22,6 +22,7 @@ pip install -r requirements.txt pip install sphinx-autobuild # This will start a local server (usually at http://127.0.0.1:8000) that automatically rebuilds and refreshes when you make changes to the documentation. +make html sphinx-autobuild source build/html ``` diff --git a/docs/source/getting_started/distributions/remote_hosted_distro/index.md b/docs/source/getting_started/distributions/remote_hosted_distro/index.md index 719f2f301..76d5fdf27 100644 --- a/docs/source/getting_started/distributions/remote_hosted_distro/index.md +++ b/docs/source/getting_started/distributions/remote_hosted_distro/index.md @@ -1,15 +1,42 @@ # Remote-Hosted Distribution -Remote Hosted distributions are distributions connecting to remote hosted services through Llama Stack server. Inference is done through remote providers. These are useful if you have an API key for a remote inference provider like Fireworks, Together, etc. +Remote-Hosted distributions are available endpoints serving Llama Stack API that you can directly connect to. 
-| **Distribution** | **Llama Stack Docker** | Start This Distribution | **Inference** | **Agents** | **Memory** | **Safety** | **Telemetry** | -|:----------------: |:------------------------------------------: |:-----------------------: |:------------------: |:------------------: |:------------------: |:------------------: |:------------------: | -| Together | [llamastack/distribution-together](https://hub.docker.com/repository/docker/llamastack/distribution-together/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/getting_started/distributions/remote_hosted_distro/together.html) | remote::together | meta-reference | remote::weaviate | meta-reference | meta-reference | -| Fireworks | [llamastack/distribution-fireworks](https://hub.docker.com/repository/docker/llamastack/distribution-fireworks/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/getting_started/distributions/remote_hosted_distro/fireworks.html) | remote::fireworks | meta-reference | remote::weaviate | meta-reference | meta-reference | +| Distribution | Endpoint | Inference | Agents | Memory | Safety | Telemetry | +|-------------|----------|-----------|---------|---------|---------|------------| +| Together | [https://llama-stack.together.ai](https://llama-stack.together.ai) | remote::together | meta-reference | remote::weaviate | meta-reference | meta-reference | +| Fireworks | [https://llamastack-preview.fireworks.ai](https://llamastack-preview.fireworks.ai) | remote::fireworks | meta-reference | remote::weaviate | meta-reference | meta-reference | -```{toctree} -:maxdepth: 1 +## Connecting to Remote-Hosted Distributions -fireworks -together +You can use `llama-stack-client` to interact with these endpoints. For example, to list the available models served by the Fireworks endpoint: + +```bash +$ pip install llama-stack-client +$ llama-stack-client configure --endpoint https://llamastack-preview.fireworks.ai +$ llama-stack-client models list ``` + +You will see outputs: +``` +$ llama-stack-client models list ++------------------------------+------------------------------+---------------+------------+ +| identifier | llama_model | provider_id | metadata | ++==============================+==============================+===============+============+ +| Llama3.1-8B-Instruct | Llama3.1-8B-Instruct | fireworks0 | {} | ++------------------------------+------------------------------+---------------+------------+ +| Llama3.1-70B-Instruct | Llama3.1-70B-Instruct | fireworks0 | {} | ++------------------------------+------------------------------+---------------+------------+ +| Llama3.1-405B-Instruct | Llama3.1-405B-Instruct | fireworks0 | {} | ++------------------------------+------------------------------+---------------+------------+ +| Llama3.2-1B-Instruct | Llama3.2-1B-Instruct | fireworks0 | {} | ++------------------------------+------------------------------+---------------+------------+ +| Llama3.2-3B-Instruct | Llama3.2-3B-Instruct | fireworks0 | {} | ++------------------------------+------------------------------+---------------+------------+ +| Llama3.2-11B-Vision-Instruct | Llama3.2-11B-Vision-Instruct | fireworks0 | {} | ++------------------------------+------------------------------+---------------+------------+ +| Llama3.2-90B-Vision-Instruct | Llama3.2-90B-Vision-Instruct | fireworks0 | {} | ++------------------------------+------------------------------+---------------+------------+ +``` + +Checkout the 
[llama-stack-client-python](https://github.com/meta-llama/llama-stack-client-python/blob/main/docs/cli_reference.md) repo for more details on how to use the `llama-stack-client` CLI. Checkout [llama-stack-app](https://github.com/meta-llama/llama-stack-apps/tree/main) for examples applications built on top of Llama Stack. diff --git a/docs/source/getting_started/distributions/remote_hosted_distro/bedrock.md b/docs/source/getting_started/distributions/self_hosted_distro/bedrock.md similarity index 100% rename from docs/source/getting_started/distributions/remote_hosted_distro/bedrock.md rename to docs/source/getting_started/distributions/self_hosted_distro/bedrock.md diff --git a/docs/source/getting_started/distributions/remote_hosted_distro/fireworks.md b/docs/source/getting_started/distributions/self_hosted_distro/fireworks.md similarity index 100% rename from docs/source/getting_started/distributions/remote_hosted_distro/fireworks.md rename to docs/source/getting_started/distributions/self_hosted_distro/fireworks.md diff --git a/docs/source/getting_started/distributions/self_hosted_distro/index.md b/docs/source/getting_started/distributions/self_hosted_distro/index.md index a2f3876ec..ed6ab5d7f 100644 --- a/docs/source/getting_started/distributions/self_hosted_distro/index.md +++ b/docs/source/getting_started/distributions/self_hosted_distro/index.md @@ -8,6 +8,10 @@ We offer deployable distributions where you can host your own Llama Stack server | Meta Reference Quantized | [llamastack/distribution-meta-reference-quantized-gpu](https://hub.docker.com/repository/docker/llamastack/distribution-meta-reference-quantized-gpu/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/getting_started/distributions/self_hosted_distro/meta-reference-quantized-gpu.html) | meta-reference-quantized | meta-reference | meta-reference; remote::pgvector; remote::chromadb | meta-reference | meta-reference | | Ollama | [llamastack/distribution-ollama](https://hub.docker.com/repository/docker/llamastack/distribution-ollama/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/getting_started/distributions/self_hosted_distro/ollama.html) | remote::ollama | meta-reference | remote::pgvector; remote::chromadb | meta-reference | meta-reference | | TGI | [llamastack/distribution-tgi](https://hub.docker.com/repository/docker/llamastack/distribution-tgi/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/getting_started/distributions/self_hosted_distro/tgi.html) | remote::tgi | meta-reference | meta-reference; remote::pgvector; remote::chromadb | meta-reference | meta-reference | +| Together | [llamastack/distribution-together](https://hub.docker.com/repository/docker/llamastack/distribution-together/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/getting_started/distributions/self_hosted_distro/together.html) | remote::together | meta-reference | remote::weaviate | meta-reference | meta-reference | +| Fireworks | [llamastack/distribution-fireworks](https://hub.docker.com/repository/docker/llamastack/distribution-fireworks/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/getting_started/distributions/self_hosted_distro/fireworks.html) | remote::fireworks | meta-reference | remote::weaviate | meta-reference | meta-reference | +| Bedrock | [llamastack/distribution-bedrock](https://hub.docker.com/repository/docker/llamastack/distribution-bedrock/general) | 
[Guide](https://llama-stack.readthedocs.io/en/latest/getting_started/distributions/self_hosted_distro/bedrock.html) | remote::bedrock | meta-reference | remote::weaviate | meta-reference | meta-reference | + ```{toctree} :maxdepth: 1 @@ -17,4 +21,7 @@ meta-reference-quantized-gpu ollama tgi dell-tgi +together +fireworks +bedrock ``` diff --git a/docs/source/getting_started/distributions/remote_hosted_distro/together.md b/docs/source/getting_started/distributions/self_hosted_distro/together.md similarity index 100% rename from docs/source/getting_started/distributions/remote_hosted_distro/together.md rename to docs/source/getting_started/distributions/self_hosted_distro/together.md From 6192bf43a4ce6ae3ac03f7fd0eea22c261f10e4d Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Thu, 7 Nov 2024 21:24:12 -0800 Subject: [PATCH 6/6] [Evals API][10/n] API updates for EvalTaskDef + new test migration (#379) * wip * scoring fn api * eval api * eval task * evaluate api update * pre commit * unwrap context -> config * config field doc * typo * naming fix * separate benchmark / app eval * api name * rename * wip tests * wip * datasetio test * delete unused * fixture * scoring resolve * fix scoring register * scoring test pass * score batch * scoring fix * fix eval * test eval works * remove type ignore * api refactor * add default task_eval_id for routing * add eval_id for jobs * remove type ignore * only keep 1 run_eval * fix optional * register task required * register task required * delete old tests * delete old tests * fixture return impl --- llama_stack/apis/eval/eval.py | 46 +++- llama_stack/apis/eval_tasks/__init__.py | 7 + llama_stack/apis/eval_tasks/eval_tasks.py | 43 ++++ llama_stack/apis/scoring/scoring.py | 6 +- .../scoring_functions/scoring_functions.py | 65 ++++-- llama_stack/distribution/distribution.py | 4 + llama_stack/distribution/resolver.py | 3 + llama_stack/distribution/routers/__init__.py | 4 + llama_stack/distribution/routers/routers.py | 83 ++++++- .../distribution/routers/routing_tables.py | 26 ++- llama_stack/providers/datatypes.py | 8 + .../inline/meta_reference/eval/eval.py | 64 ++++-- .../inline/meta_reference/scoring/scoring.py | 19 +- .../scoring/scoring_fn/base_scoring_fn.py | 8 +- .../scoring/scoring_fn/equality_scoring_fn.py | 1 + .../fn_defs/llm_as_judge_8b_correctness.py | 8 +- .../scoring_fn/llm_as_judge_scoring_fn.py | 22 +- .../scoring_fn/subset_of_scoring_fn.py | 1 + llama_stack/providers/tests/conftest.py | 3 + .../providers/tests/datasetio/conftest.py | 29 +++ .../providers/tests/datasetio/fixtures.py | 48 ++++ .../datasetio/provider_config_example.yaml | 4 - .../tests/datasetio/test_datasetio.py | 126 ++++------ llama_stack/providers/tests/eval/conftest.py | 72 ++++++ llama_stack/providers/tests/eval/fixtures.py | 55 +++++ .../tests/eval/provider_config_example.yaml | 22 -- llama_stack/providers/tests/eval/test_eval.py | 167 +++++++++----- .../providers/tests/inference/fixtures.py | 1 + .../providers/tests/scoring/conftest.py | 68 ++++++ .../providers/tests/scoring/fixtures.py | 60 +++++ .../scoring/provider_config_example.yaml | 17 -- .../providers/tests/scoring/test_scoring.py | 215 +++++++----------- 32 files changed, 916 insertions(+), 389 deletions(-) create mode 100644 llama_stack/apis/eval_tasks/__init__.py create mode 100644 llama_stack/apis/eval_tasks/eval_tasks.py create mode 100644 llama_stack/providers/tests/datasetio/conftest.py create mode 100644 llama_stack/providers/tests/datasetio/fixtures.py delete mode 100644 
llama_stack/providers/tests/datasetio/provider_config_example.yaml create mode 100644 llama_stack/providers/tests/eval/conftest.py create mode 100644 llama_stack/providers/tests/eval/fixtures.py delete mode 100644 llama_stack/providers/tests/eval/provider_config_example.yaml create mode 100644 llama_stack/providers/tests/scoring/conftest.py create mode 100644 llama_stack/providers/tests/scoring/fixtures.py delete mode 100644 llama_stack/providers/tests/scoring/provider_config_example.yaml diff --git a/llama_stack/apis/eval/eval.py b/llama_stack/apis/eval/eval.py index 51f49da15..50fb922fe 100644 --- a/llama_stack/apis/eval/eval.py +++ b/llama_stack/apis/eval/eval.py @@ -14,6 +14,7 @@ from llama_stack.apis.scoring_functions import * # noqa: F403 from llama_stack.apis.agents import AgentConfig from llama_stack.apis.common.job_types import Job, JobStatus from llama_stack.apis.scoring import * # noqa: F403 +from llama_stack.apis.eval_tasks import * # noqa: F403 @json_schema_type @@ -35,36 +36,57 @@ EvalCandidate = Annotated[ ] +@json_schema_type +class BenchmarkEvalTaskConfig(BaseModel): + type: Literal["benchmark"] = "benchmark" + eval_candidate: EvalCandidate + + +@json_schema_type +class AppEvalTaskConfig(BaseModel): + type: Literal["app"] = "app" + eval_candidate: EvalCandidate + scoring_params: Dict[str, ScoringFnParams] = Field( + description="Map between scoring function id and parameters for each scoring function you want to run", + default_factory=dict, + ) + # we could optinally add any specific dataset config here + + +EvalTaskConfig = Annotated[ + Union[BenchmarkEvalTaskConfig, AppEvalTaskConfig], Field(discriminator="type") +] + + @json_schema_type class EvaluateResponse(BaseModel): generations: List[Dict[str, Any]] - # each key in the dict is a scoring function name scores: Dict[str, ScoringResult] class Eval(Protocol): - @webmethod(route="/eval/evaluate_batch", method="POST") - async def evaluate_batch( + @webmethod(route="/eval/run_eval", method="POST") + async def run_eval( self, - dataset_id: str, - candidate: EvalCandidate, - scoring_functions: List[str], + task_id: str, + task_config: EvalTaskConfig, ) -> Job: ... - @webmethod(route="/eval/evaluate", method="POST") - async def evaluate( + @webmethod(route="/eval/evaluate_rows", method="POST") + async def evaluate_rows( self, + task_id: str, input_rows: List[Dict[str, Any]], - candidate: EvalCandidate, scoring_functions: List[str], + task_config: EvalTaskConfig, ) -> EvaluateResponse: ... @webmethod(route="/eval/job/status", method="GET") - async def job_status(self, job_id: str) -> Optional[JobStatus]: ... + async def job_status(self, task_id: str, job_id: str) -> Optional[JobStatus]: ... @webmethod(route="/eval/job/cancel", method="POST") - async def job_cancel(self, job_id: str) -> None: ... + async def job_cancel(self, task_id: str, job_id: str) -> None: ... @webmethod(route="/eval/job/result", method="GET") - async def job_result(self, job_id: str) -> EvaluateResponse: ... + async def job_result(self, task_id: str, job_id: str) -> EvaluateResponse: ... diff --git a/llama_stack/apis/eval_tasks/__init__.py b/llama_stack/apis/eval_tasks/__init__.py new file mode 100644 index 000000000..7ca216706 --- /dev/null +++ b/llama_stack/apis/eval_tasks/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +from .eval_tasks import * # noqa: F401 F403 diff --git a/llama_stack/apis/eval_tasks/eval_tasks.py b/llama_stack/apis/eval_tasks/eval_tasks.py new file mode 100644 index 000000000..0007066aa --- /dev/null +++ b/llama_stack/apis/eval_tasks/eval_tasks.py @@ -0,0 +1,43 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. +from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable + +from llama_models.schema_utils import json_schema_type, webmethod + +from pydantic import BaseModel, Field + + +@json_schema_type +class EvalTaskDef(BaseModel): + identifier: str + dataset_id: str + scoring_functions: List[str] + metadata: Dict[str, Any] = Field( + default_factory=dict, + description="Metadata for this evaluation task", + ) + + +@json_schema_type +class EvalTaskDefWithProvider(EvalTaskDef): + type: Literal["eval_task"] = "eval_task" + provider_id: str = Field( + description="ID of the provider which serves this dataset", + ) + + +@runtime_checkable +class EvalTasks(Protocol): + @webmethod(route="/eval_tasks/list", method="GET") + async def list_eval_tasks(self) -> List[EvalTaskDefWithProvider]: ... + + @webmethod(route="/eval_tasks/get", method="GET") + async def get_eval_task(self, name: str) -> Optional[EvalTaskDefWithProvider]: ... + + @webmethod(route="/eval_tasks/register", method="POST") + async def register_eval_task( + self, eval_task_def: EvalTaskDefWithProvider + ) -> None: ... diff --git a/llama_stack/apis/scoring/scoring.py b/llama_stack/apis/scoring/scoring.py index 1fd523dcb..c2bfdcd23 100644 --- a/llama_stack/apis/scoring/scoring.py +++ b/llama_stack/apis/scoring/scoring.py @@ -48,11 +48,13 @@ class Scoring(Protocol): async def score_batch( self, dataset_id: str, - scoring_functions: List[str], + scoring_functions: Dict[str, Optional[ScoringFnParams]] = None, save_results_dataset: bool = False, ) -> ScoreBatchResponse: ... @webmethod(route="/scoring/score") async def score( - self, input_rows: List[Dict[str, Any]], scoring_functions: List[str] + self, + input_rows: List[Dict[str, Any]], + scoring_functions: Dict[str, Optional[ScoringFnParams]] = None, ) -> ScoreResponse: ... diff --git a/llama_stack/apis/scoring_functions/scoring_functions.py b/llama_stack/apis/scoring_functions/scoring_functions.py index d0a9cc597..140376242 100644 --- a/llama_stack/apis/scoring_functions/scoring_functions.py +++ b/llama_stack/apis/scoring_functions/scoring_functions.py @@ -4,34 +4,66 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable +from enum import Enum +from typing import ( + Any, + Dict, + List, + Literal, + Optional, + Protocol, + runtime_checkable, + Union, +) from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel, Field +from typing_extensions import Annotated from llama_stack.apis.common.type_system import ParamType -@json_schema_type -class Parameter(BaseModel): - name: str - type: ParamType - description: Optional[str] = None - - # Perhaps more structure can be imposed on these functions. Maybe they could be associated # with standard metrics so they can be rolled up? 
+@json_schema_type +class ScoringConfigType(Enum): + llm_as_judge = "llm_as_judge" + regex_parser = "regex_parser" -class LLMAsJudgeContext(BaseModel): +@json_schema_type +class LLMAsJudgeScoringFnParams(BaseModel): + type: Literal[ScoringConfigType.llm_as_judge.value] = ( + ScoringConfigType.llm_as_judge.value + ) judge_model: str prompt_template: Optional[str] = None - judge_score_regex: Optional[List[str]] = Field( - description="Regex to extract the score from the judge response", - default=None, + judge_score_regexes: Optional[List[str]] = Field( + description="Regexes to extract the answer from generated response", + default_factory=list, ) +@json_schema_type +class RegexParserScoringFnParams(BaseModel): + type: Literal[ScoringConfigType.regex_parser.value] = ( + ScoringConfigType.regex_parser.value + ) + parsing_regexes: Optional[List[str]] = Field( + description="Regex to extract the answer from generated response", + default_factory=list, + ) + + +ScoringFnParams = Annotated[ + Union[ + LLMAsJudgeScoringFnParams, + RegexParserScoringFnParams, + ], + Field(discriminator="type"), +] + + @json_schema_type class ScoringFnDef(BaseModel): identifier: str @@ -40,14 +72,13 @@ class ScoringFnDef(BaseModel): default_factory=dict, description="Any additional metadata for this definition", ) - parameters: List[Parameter] = Field( - description="List of parameters for the deterministic function", - default_factory=list, - ) return_type: ParamType = Field( description="The return type of the deterministic function", ) - context: Optional[LLMAsJudgeContext] = None + params: Optional[ScoringFnParams] = Field( + description="The parameters for the scoring function for benchmark eval, these can be overridden for app eval", + default=None, + ) # We can optionally add information here to support packaging of code, etc. 
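The new `ScoringFnParams` union above, together with the updated `Scoring.score()` signature earlier in this patch (a `Dict[str, Optional[ScoringFnParams]]` instead of a `List[str]`), lets callers override judge settings per call. The snippet below is a minimal illustrative sketch, not part of the patch: the import path and the `scoring_impl` handle are assumptions, while the field names and scoring function ids come from the diffs and tests in this series.

```python
# Illustrative sketch only (not part of the patch). Assumes the new params
# models are importable from llama_stack.apis.scoring_functions and that a
# resolved Scoring implementation is available as `scoring_impl`.
from llama_stack.apis.scoring_functions import LLMAsJudgeScoringFnParams


async def score_with_overrides(scoring_impl, rows):
    # Per-call override for an LLM-judge scoring function; the discriminator
    # field `type` defaults to "llm_as_judge" via the Literal default.
    judge_params = LLMAsJudgeScoringFnParams(
        judge_model="Llama3.1-8B-Instruct",
        prompt_template=(
            "Question: {input_query}\nExpected: {expected_answer}\n"
            "Answer: {generated_answer}\nEnd with 'Total rating: <int>'."
        ),
        judge_score_regexes=[r"Total rating: (\d+)"],
    )

    # score() now takes a mapping of scoring function id -> optional params;
    # a None value means "use the function's registered defaults".
    return await scoring_impl.score(
        input_rows=rows,
        scoring_functions={
            "meta-reference::llm_as_judge_8b_correctness": judge_params,
            "meta-reference::equality": None,
        },
    )
```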
diff --git a/llama_stack/distribution/distribution.py b/llama_stack/distribution/distribution.py index 2149162a6..3fc3b2d5d 100644 --- a/llama_stack/distribution/distribution.py +++ b/llama_stack/distribution/distribution.py @@ -43,6 +43,10 @@ def builtin_automatically_routed_apis() -> List[AutoRoutedApiInfo]: routing_table_api=Api.scoring_functions, router_api=Api.scoring, ), + AutoRoutedApiInfo( + routing_table_api=Api.eval_tasks, + router_api=Api.eval, + ), ] diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py index 9b8e41561..aac7ae5b6 100644 --- a/llama_stack/distribution/resolver.py +++ b/llama_stack/distribution/resolver.py @@ -17,6 +17,7 @@ from llama_stack.apis.agents import Agents from llama_stack.apis.datasetio import DatasetIO from llama_stack.apis.datasets import Datasets from llama_stack.apis.eval import Eval +from llama_stack.apis.eval_tasks import EvalTasks from llama_stack.apis.inference import Inference from llama_stack.apis.inspect import Inspect from llama_stack.apis.memory import Memory @@ -48,6 +49,7 @@ def api_protocol_map() -> Dict[Api, Any]: Api.scoring: Scoring, Api.scoring_functions: ScoringFunctions, Api.eval: Eval, + Api.eval_tasks: EvalTasks, } @@ -58,6 +60,7 @@ def additional_protocols_map() -> Dict[Api, Any]: Api.safety: (ShieldsProtocolPrivate, Shields), Api.datasetio: (DatasetsProtocolPrivate, Datasets), Api.scoring: (ScoringFunctionsProtocolPrivate, ScoringFunctions), + Api.eval_tasks: (EvalTasksProtocolPrivate, EvalTasks), } diff --git a/llama_stack/distribution/routers/__init__.py b/llama_stack/distribution/routers/__init__.py index b3ebd1368..57e81ac30 100644 --- a/llama_stack/distribution/routers/__init__.py +++ b/llama_stack/distribution/routers/__init__.py @@ -12,6 +12,7 @@ from llama_stack.distribution.store import DistributionRegistry from .routing_tables import ( DatasetsRoutingTable, + EvalTasksRoutingTable, MemoryBanksRoutingTable, ModelsRoutingTable, ScoringFunctionsRoutingTable, @@ -31,6 +32,7 @@ async def get_routing_table_impl( "shields": ShieldsRoutingTable, "datasets": DatasetsRoutingTable, "scoring_functions": ScoringFunctionsRoutingTable, + "eval_tasks": EvalTasksRoutingTable, } if api.value not in api_to_tables: @@ -44,6 +46,7 @@ async def get_routing_table_impl( async def get_auto_router_impl(api: Api, routing_table: RoutingTable, _deps) -> Any: from .routers import ( DatasetIORouter, + EvalRouter, InferenceRouter, MemoryRouter, SafetyRouter, @@ -56,6 +59,7 @@ async def get_auto_router_impl(api: Api, routing_table: RoutingTable, _deps) -> "safety": SafetyRouter, "datasetio": DatasetIORouter, "scoring": ScoringRouter, + "eval": EvalRouter, } if api.value not in api_to_routers: raise ValueError(f"API {api.value} not found in router map") diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py index 760dbaf2f..8edf950b2 100644 --- a/llama_stack/distribution/routers/routers.py +++ b/llama_stack/distribution/routers/routers.py @@ -14,6 +14,7 @@ from llama_stack.apis.inference import * # noqa: F403 from llama_stack.apis.safety import * # noqa: F403 from llama_stack.apis.datasetio import * # noqa: F403 from llama_stack.apis.scoring import * # noqa: F403 +from llama_stack.apis.eval import * # noqa: F403 class MemoryRouter(Memory): @@ -211,16 +212,16 @@ class ScoringRouter(Scoring): async def score_batch( self, dataset_id: str, - scoring_functions: List[str], + scoring_functions: Dict[str, Optional[ScoringFnParams]] = None, save_results_dataset: bool = False, ) 
-> ScoreBatchResponse: res = {} - for fn_identifier in scoring_functions: + for fn_identifier in scoring_functions.keys(): score_response = await self.routing_table.get_provider_impl( fn_identifier ).score_batch( dataset_id=dataset_id, - scoring_functions=[fn_identifier], + scoring_functions={fn_identifier: scoring_functions[fn_identifier]}, ) res.update(score_response.results) @@ -232,17 +233,87 @@ class ScoringRouter(Scoring): ) async def score( - self, input_rows: List[Dict[str, Any]], scoring_functions: List[str] + self, + input_rows: List[Dict[str, Any]], + scoring_functions: Dict[str, Optional[ScoringFnParams]] = None, ) -> ScoreResponse: res = {} # look up and map each scoring function to its provider impl - for fn_identifier in scoring_functions: + for fn_identifier in scoring_functions.keys(): score_response = await self.routing_table.get_provider_impl( fn_identifier ).score( input_rows=input_rows, - scoring_functions=[fn_identifier], + scoring_functions={fn_identifier: scoring_functions[fn_identifier]}, ) res.update(score_response.results) return ScoreResponse(results=res) + + +class EvalRouter(Eval): + def __init__( + self, + routing_table: RoutingTable, + ) -> None: + self.routing_table = routing_table + + async def initialize(self) -> None: + pass + + async def shutdown(self) -> None: + pass + + async def run_eval( + self, + task_id: str, + task_config: AppEvalTaskConfig, + ) -> Job: + return await self.routing_table.get_provider_impl(task_id).run_eval( + task_id=task_id, + task_config=task_config, + ) + + @webmethod(route="/eval/evaluate_rows", method="POST") + async def evaluate_rows( + self, + task_id: str, + input_rows: List[Dict[str, Any]], + scoring_functions: List[str], + task_config: EvalTaskConfig, + ) -> EvaluateResponse: + return await self.routing_table.get_provider_impl(task_id).evaluate_rows( + task_id=task_id, + input_rows=input_rows, + scoring_functions=scoring_functions, + task_config=task_config, + ) + + async def job_status( + self, + task_id: str, + job_id: str, + ) -> Optional[JobStatus]: + return await self.routing_table.get_provider_impl(task_id).job_status( + task_id, job_id + ) + + async def job_cancel( + self, + task_id: str, + job_id: str, + ) -> None: + await self.routing_table.get_provider_impl(task_id).job_cancel( + task_id, + job_id, + ) + + async def job_result( + self, + task_id: str, + job_id: str, + ) -> EvaluateResponse: + return await self.routing_table.get_provider_impl(task_id).job_result( + task_id, + job_id, + ) diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index bcf125bec..a676b5fef 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -12,6 +12,8 @@ from llama_stack.apis.models import * # noqa: F403 from llama_stack.apis.shields import * # noqa: F403 from llama_stack.apis.memory_banks import * # noqa: F403 from llama_stack.apis.datasets import * # noqa: F403 +from llama_stack.apis.eval_tasks import * # noqa: F403 + from llama_stack.distribution.store import DistributionRegistry from llama_stack.distribution.datatypes import * # noqa: F403 @@ -40,6 +42,8 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> None: await p.register_dataset(obj) elif api == Api.scoring: await p.register_scoring_function(obj) + elif api == Api.eval: + await p.register_eval_task(obj) else: raise ValueError(f"Unknown API {api} for registering object with provider") @@ -103,6 +107,11 @@ class 
CommonRoutingTableImpl(RoutingTable): scoring_functions = await p.list_scoring_functions() await add_objects(scoring_functions, pid, ScoringFnDefWithProvider) + elif api == Api.eval: + p.eval_task_store = self + eval_tasks = await p.list_eval_tasks() + await add_objects(eval_tasks, pid, EvalTaskDefWithProvider) + async def shutdown(self) -> None: for p in self.impls_by_provider_id.values(): await p.shutdown() @@ -121,6 +130,8 @@ class CommonRoutingTableImpl(RoutingTable): return ("DatasetIO", "dataset") elif isinstance(self, ScoringFunctionsRoutingTable): return ("Scoring", "scoring_function") + elif isinstance(self, EvalTasksRoutingTable): + return ("Eval", "eval_task") else: raise ValueError("Unknown routing table type") @@ -246,9 +257,9 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets): await self.register_object(dataset_def) -class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, Scoring): +class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions): async def list_scoring_functions(self) -> List[ScoringFnDefWithProvider]: - return await self.get_all_with_type("scoring_function") + return await self.get_all_with_type("scoring_fn") async def get_scoring_function( self, name: str @@ -259,3 +270,14 @@ class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, Scoring): self, function_def: ScoringFnDefWithProvider ) -> None: await self.register_object(function_def) + + +class EvalTasksRoutingTable(CommonRoutingTableImpl, EvalTasks): + async def list_eval_tasks(self) -> List[ScoringFnDefWithProvider]: + return await self.get_all_with_type("eval_task") + + async def get_eval_task(self, name: str) -> Optional[EvalTaskDefWithProvider]: + return await self.get_object_by_identifier(name) + + async def register_eval_task(self, eval_task_def: EvalTaskDefWithProvider) -> None: + await self.register_object(eval_task_def) diff --git a/llama_stack/providers/datatypes.py b/llama_stack/providers/datatypes.py index 59c5a38fa..0f82ca592 100644 --- a/llama_stack/providers/datatypes.py +++ b/llama_stack/providers/datatypes.py @@ -12,6 +12,7 @@ from llama_models.schema_utils import json_schema_type from pydantic import BaseModel, Field from llama_stack.apis.datasets import DatasetDef +from llama_stack.apis.eval_tasks import EvalTaskDef from llama_stack.apis.memory_banks import MemoryBankDef from llama_stack.apis.models import ModelDef from llama_stack.apis.scoring_functions import ScoringFnDef @@ -35,6 +36,7 @@ class Api(Enum): memory_banks = "memory_banks" datasets = "datasets" scoring_functions = "scoring_functions" + eval_tasks = "eval_tasks" # built-in API inspect = "inspect" @@ -70,6 +72,12 @@ class ScoringFunctionsProtocolPrivate(Protocol): async def register_scoring_function(self, function_def: ScoringFnDef) -> None: ... +class EvalTasksProtocolPrivate(Protocol): + async def list_eval_tasks(self) -> List[EvalTaskDef]: ... + + async def register_eval_task(self, eval_task_def: EvalTaskDef) -> None: ... 
+ + @json_schema_type class ProviderSpec(BaseModel): api: Api diff --git a/llama_stack/providers/inline/meta_reference/eval/eval.py b/llama_stack/providers/inline/meta_reference/eval/eval.py index 3aec6170f..4a61c9d93 100644 --- a/llama_stack/providers/inline/meta_reference/eval/eval.py +++ b/llama_stack/providers/inline/meta_reference/eval/eval.py @@ -6,13 +6,15 @@ from enum import Enum from llama_models.llama3.api.datatypes import * # noqa: F403 +from .....apis.common.job_types import Job +from .....apis.eval.eval import Eval, EvalTaskConfig, EvaluateResponse, JobStatus from llama_stack.apis.common.type_system import * # noqa: F403 -from llama_stack.apis.common.job_types import Job from llama_stack.apis.datasetio import DatasetIO from llama_stack.apis.datasets import Datasets -from llama_stack.apis.eval import Eval, EvalCandidate, EvaluateResponse, JobStatus +from llama_stack.apis.eval_tasks import EvalTaskDef from llama_stack.apis.inference import Inference from llama_stack.apis.scoring import Scoring +from llama_stack.providers.datatypes import EvalTasksProtocolPrivate from .config import MetaReferenceEvalConfig @@ -25,7 +27,7 @@ class ColumnName(Enum): generated_answer = "generated_answer" -class MetaReferenceEvalImpl(Eval): +class MetaReferenceEvalImpl(Eval, EvalTasksProtocolPrivate): def __init__( self, config: MetaReferenceEvalConfig, @@ -43,10 +45,18 @@ class MetaReferenceEvalImpl(Eval): # TODO: assume sync job, will need jobs API for async scheduling self.jobs = {} + self.eval_tasks = {} + async def initialize(self) -> None: ... async def shutdown(self) -> None: ... + async def register_eval_task(self, task_def: EvalTaskDef) -> None: + self.eval_tasks[task_def.identifier] = task_def + + async def list_eval_tasks(self) -> List[EvalTaskDef]: + return list(self.eval_tasks.values()) + async def validate_eval_input_dataset_schema(self, dataset_id: str) -> None: dataset_def = await self.datasets_api.get_dataset(dataset_identifier=dataset_id) if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0: @@ -70,21 +80,26 @@ class MetaReferenceEvalImpl(Eval): f"Dataset {dataset_id} does not have a correct input schema in {expected_schemas}" ) - async def evaluate_batch( + async def run_eval( self, - dataset_id: str, - candidate: EvalCandidate, - scoring_functions: List[str], + task_id: str, + task_config: EvalTaskConfig, ) -> Job: + task_def = self.eval_tasks[task_id] + dataset_id = task_def.dataset_id + candidate = task_config.eval_candidate + scoring_functions = task_def.scoring_functions + await self.validate_eval_input_dataset_schema(dataset_id=dataset_id) all_rows = await self.datasetio_api.get_rows_paginated( dataset_id=dataset_id, rows_in_page=-1, ) - res = await self.evaluate( + res = await self.evaluate_rows( + task_id=task_id, input_rows=all_rows.rows, - candidate=candidate, scoring_functions=scoring_functions, + task_config=task_config, ) # TODO: currently needs to wait for generation before returning @@ -93,12 +108,14 @@ class MetaReferenceEvalImpl(Eval): self.jobs[job_id] = res return Job(job_id=job_id) - async def evaluate( + async def evaluate_rows( self, + task_id: str, input_rows: List[Dict[str, Any]], - candidate: EvalCandidate, scoring_functions: List[str], + task_config: EvalTaskConfig, ) -> EvaluateResponse: + candidate = task_config.eval_candidate if candidate.type == "agent": raise NotImplementedError( "Evaluation with generation has not been implemented for agents" @@ -122,7 +139,10 @@ class MetaReferenceEvalImpl(Eval): } ) elif 
ColumnName.chat_completion_input.value in x: - input_messages = eval(str(x[ColumnName.chat_completion_input.value])) + chat_completion_input_str = str( + x[ColumnName.chat_completion_input.value] + ) + input_messages = eval(chat_completion_input_str) input_messages = [UserMessage(**x) for x in input_messages] messages = [] if candidate.system_message: @@ -147,23 +167,33 @@ class MetaReferenceEvalImpl(Eval): for input_r, generated_r in zip(input_rows, generations) ] + if task_config.type == "app" and task_config.scoring_params is not None: + scoring_functions_dict = { + scoring_fn_id: task_config.scoring_params.get(scoring_fn_id, None) + for scoring_fn_id in scoring_functions + } + else: + scoring_functions_dict = { + scoring_fn_id: None for scoring_fn_id in scoring_functions + } + score_response = await self.scoring_api.score( - input_rows=score_input_rows, scoring_functions=scoring_functions + input_rows=score_input_rows, scoring_functions=scoring_functions_dict ) return EvaluateResponse(generations=generations, scores=score_response.results) - async def job_status(self, job_id: str) -> Optional[JobStatus]: + async def job_status(self, task_id: str, job_id: str) -> Optional[JobStatus]: if job_id in self.jobs: return JobStatus.completed return None - async def job_cancel(self, job_id: str) -> None: + async def job_cancel(self, task_id: str, job_id: str) -> None: raise NotImplementedError("Job cancel is not implemented yet") - async def job_result(self, job_id: str) -> EvaluateResponse: - status = await self.job_status(job_id) + async def job_result(self, task_id: str, job_id: str) -> EvaluateResponse: + status = await self.job_status(task_id, job_id) if not status or status != JobStatus.completed: raise ValueError(f"Job is not completed, Status: {status.value}") diff --git a/llama_stack/providers/inline/meta_reference/scoring/scoring.py b/llama_stack/providers/inline/meta_reference/scoring/scoring.py index 709b2f0c6..c4add966d 100644 --- a/llama_stack/providers/inline/meta_reference/scoring/scoring.py +++ b/llama_stack/providers/inline/meta_reference/scoring/scoring.py @@ -74,8 +74,7 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate): return scoring_fn_defs_list async def register_scoring_function(self, function_def: ScoringFnDef) -> None: - self.llm_as_judge_fn.register_scoring_fn_def(function_def) - self.scoring_fn_id_impls[function_def.identifier] = self.llm_as_judge_fn + raise NotImplementedError("Register scoring function not implemented yet") async def validate_scoring_input_dataset_schema(self, dataset_id: str) -> None: dataset_def = await self.datasets_api.get_dataset(dataset_identifier=dataset_id) @@ -97,7 +96,7 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate): async def score_batch( self, dataset_id: str, - scoring_functions: List[str], + scoring_functions: Dict[str, Optional[ScoringFnParams]] = None, save_results_dataset: bool = False, ) -> ScoreBatchResponse: await self.validate_scoring_input_dataset_schema(dataset_id=dataset_id) @@ -106,7 +105,8 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate): rows_in_page=-1, ) res = await self.score( - input_rows=all_rows.rows, scoring_functions=scoring_functions + input_rows=all_rows.rows, + scoring_functions=scoring_functions, ) if save_results_dataset: # TODO: persist and register dataset on to server for reading @@ -118,14 +118,19 @@ class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate): ) async def score( - self, input_rows: List[Dict[str, 
Any]], scoring_functions: List[str] + self, + input_rows: List[Dict[str, Any]], + scoring_functions: Dict[str, Optional[ScoringFnParams]] = None, ) -> ScoreResponse: res = {} - for scoring_fn_id in scoring_functions: + for scoring_fn_id in scoring_functions.keys(): if scoring_fn_id not in self.scoring_fn_id_impls: raise ValueError(f"Scoring function {scoring_fn_id} is not supported.") scoring_fn = self.scoring_fn_id_impls[scoring_fn_id] - score_results = await scoring_fn.score(input_rows, scoring_fn_id) + scoring_fn_params = scoring_functions.get(scoring_fn_id, None) + score_results = await scoring_fn.score( + input_rows, scoring_fn_id, scoring_fn_params + ) agg_results = await scoring_fn.aggregate(score_results) res[scoring_fn_id] = ScoringResult( score_rows=score_results, diff --git a/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/base_scoring_fn.py b/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/base_scoring_fn.py index cbd875be6..532686ebd 100644 --- a/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/base_scoring_fn.py +++ b/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/base_scoring_fn.py @@ -36,7 +36,10 @@ class BaseScoringFn(ABC): @abstractmethod async def score_row( - self, input_row: Dict[str, Any], scoring_fn_identifier: Optional[str] = None + self, + input_row: Dict[str, Any], + scoring_fn_identifier: Optional[str] = None, + scoring_params: Optional[ScoringFnParams] = None, ) -> ScoringResultRow: raise NotImplementedError() @@ -50,8 +53,9 @@ class BaseScoringFn(ABC): self, input_rows: List[Dict[str, Any]], scoring_fn_identifier: Optional[str] = None, + scoring_params: Optional[ScoringFnParams] = None, ) -> List[ScoringResultRow]: return [ - await self.score_row(input_row, scoring_fn_identifier) + await self.score_row(input_row, scoring_fn_identifier, scoring_params) for input_row in input_rows ] diff --git a/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/equality_scoring_fn.py b/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/equality_scoring_fn.py index 2a0cd0578..07405d56c 100644 --- a/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/equality_scoring_fn.py +++ b/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/equality_scoring_fn.py @@ -35,6 +35,7 @@ class EqualityScoringFn(BaseScoringFn): self, input_row: Dict[str, Any], scoring_fn_identifier: Optional[str] = "equality", + scoring_params: Optional[ScoringFnParams] = None, ) -> ScoringResultRow: assert "expected_answer" in input_row, "Expected answer not found in input row." 
assert ( diff --git a/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/fn_defs/llm_as_judge_8b_correctness.py b/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/fn_defs/llm_as_judge_8b_correctness.py index 20a67edc7..cfef52160 100644 --- a/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/fn_defs/llm_as_judge_8b_correctness.py +++ b/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/fn_defs/llm_as_judge_8b_correctness.py @@ -28,9 +28,13 @@ llm_as_judge_8b_correctness = ScoringFnDef( description="Llm As Judge Scoring Function", parameters=[], return_type=NumberType(), - context=LLMAsJudgeContext( + params=LLMAsJudgeScoringFnParams( prompt_template=JUDGE_PROMPT, judge_model="Llama3.1-8B-Instruct", - judge_score_regex=[r"Total rating: (\d+)", r"rating: (\d+)", r"Rating: (\d+)"], + judge_score_regexes=[ + r"Total rating: (\d+)", + r"rating: (\d+)", + r"Rating: (\d+)", + ], ), ) diff --git a/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/llm_as_judge_scoring_fn.py b/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/llm_as_judge_scoring_fn.py index 84dd28fd7..f98f7fb5e 100644 --- a/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/llm_as_judge_scoring_fn.py +++ b/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/llm_as_judge_scoring_fn.py @@ -36,31 +36,37 @@ class LlmAsJudgeScoringFn(BaseScoringFn): self, input_row: Dict[str, Any], scoring_fn_identifier: Optional[str] = None, + scoring_params: Optional[ScoringFnParams] = None, ) -> ScoringResultRow: assert ( scoring_fn_identifier is not None ), "Scoring function identifier not found." fn_def = self.supported_fn_defs_registry[scoring_fn_identifier] - assert fn_def.context is not None, f"LLMAsJudgeContext not found for {fn_def}." + + # override params if scoring_params is provided + if scoring_params is not None: + fn_def.params = scoring_params + + assert fn_def.params is not None, f"LLMAsJudgeparams not found for {fn_def}." assert ( - fn_def.context.prompt_template is not None + fn_def.params.prompt_template is not None ), "LLM Judge prompt_template not found." assert ( - fn_def.context.judge_score_regex is not None - ), "LLM Judge judge_score_regex not found." + fn_def.params.judge_score_regexes is not None + ), "LLM Judge judge_score_regexes not found." 
input_query = input_row["input_query"] expected_answer = input_row["expected_answer"] generated_answer = input_row["generated_answer"] - judge_input_msg = fn_def.context.prompt_template.format( + judge_input_msg = fn_def.params.prompt_template.format( input_query=input_query, expected_answer=expected_answer, generated_answer=generated_answer, ) judge_response = await self.inference_api.chat_completion( - model=fn_def.context.judge_model, + model=fn_def.params.judge_model, messages=[ { "role": "user", @@ -69,10 +75,10 @@ class LlmAsJudgeScoringFn(BaseScoringFn): ], ) content = judge_response.completion_message.content - rating_regexs = fn_def.context.judge_score_regex + rating_regexes = fn_def.params.judge_score_regexes judge_rating = None - for regex in rating_regexs: + for regex in rating_regexes: match = re.search(regex, content) if match: judge_rating = int(match.group(1)) diff --git a/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/subset_of_scoring_fn.py b/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/subset_of_scoring_fn.py index f42964c1f..289c63dd7 100644 --- a/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/subset_of_scoring_fn.py +++ b/llama_stack/providers/inline/meta_reference/scoring/scoring_fn/subset_of_scoring_fn.py @@ -34,6 +34,7 @@ class SubsetOfScoringFn(BaseScoringFn): self, input_row: Dict[str, Any], scoring_fn_identifier: Optional[str] = "subset_of", + scoring_params: Optional[ScoringFnParams] = None, ) -> ScoringResultRow: expected_answer = input_row["expected_answer"] generated_answer = input_row["generated_answer"] diff --git a/llama_stack/providers/tests/conftest.py b/llama_stack/providers/tests/conftest.py index 2278e1a6c..3bec2d11d 100644 --- a/llama_stack/providers/tests/conftest.py +++ b/llama_stack/providers/tests/conftest.py @@ -153,4 +153,7 @@ pytest_plugins = [ "llama_stack.providers.tests.safety.fixtures", "llama_stack.providers.tests.memory.fixtures", "llama_stack.providers.tests.agents.fixtures", + "llama_stack.providers.tests.datasetio.fixtures", + "llama_stack.providers.tests.scoring.fixtures", + "llama_stack.providers.tests.eval.fixtures", ] diff --git a/llama_stack/providers/tests/datasetio/conftest.py b/llama_stack/providers/tests/datasetio/conftest.py new file mode 100644 index 000000000..740eddb33 --- /dev/null +++ b/llama_stack/providers/tests/datasetio/conftest.py @@ -0,0 +1,29 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import pytest + +from .fixtures import DATASETIO_FIXTURES + + +def pytest_configure(config): + for fixture_name in DATASETIO_FIXTURES: + config.addinivalue_line( + "markers", + f"{fixture_name}: marks tests as {fixture_name} specific", + ) + + +def pytest_generate_tests(metafunc): + if "datasetio_stack" in metafunc.fixturenames: + metafunc.parametrize( + "datasetio_stack", + [ + pytest.param(fixture_name, marks=getattr(pytest.mark, fixture_name)) + for fixture_name in DATASETIO_FIXTURES + ], + indirect=True, + ) diff --git a/llama_stack/providers/tests/datasetio/fixtures.py b/llama_stack/providers/tests/datasetio/fixtures.py new file mode 100644 index 000000000..7d7615b55 --- /dev/null +++ b/llama_stack/providers/tests/datasetio/fixtures.py @@ -0,0 +1,48 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import pytest +import pytest_asyncio + +from llama_stack.distribution.datatypes import Api, Provider + +from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2 +from ..conftest import ProviderFixture, remote_stack_fixture + + +@pytest.fixture(scope="session") +def datasetio_remote() -> ProviderFixture: + return remote_stack_fixture() + + +@pytest.fixture(scope="session") +def datasetio_meta_reference() -> ProviderFixture: + return ProviderFixture( + providers=[ + Provider( + provider_id="meta-reference", + provider_type="meta-reference", + config={}, + ) + ], + ) + + +DATASETIO_FIXTURES = ["meta_reference", "remote"] + + +@pytest_asyncio.fixture(scope="session") +async def datasetio_stack(request): + fixture_name = request.param + fixture = request.getfixturevalue(f"datasetio_{fixture_name}") + + impls = await resolve_impls_for_test_v2( + [Api.datasetio], + {"datasetio": fixture.providers}, + fixture.provider_data, + ) + + return impls[Api.datasetio], impls[Api.datasets] diff --git a/llama_stack/providers/tests/datasetio/provider_config_example.yaml b/llama_stack/providers/tests/datasetio/provider_config_example.yaml deleted file mode 100644 index c0565a39e..000000000 --- a/llama_stack/providers/tests/datasetio/provider_config_example.yaml +++ /dev/null @@ -1,4 +0,0 @@ -providers: - - provider_id: test-meta - provider_type: meta-reference - config: {} diff --git a/llama_stack/providers/tests/datasetio/test_datasetio.py b/llama_stack/providers/tests/datasetio/test_datasetio.py index 866b1e270..c02794c50 100644 --- a/llama_stack/providers/tests/datasetio/test_datasetio.py +++ b/llama_stack/providers/tests/datasetio/test_datasetio.py @@ -3,11 +3,10 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. + import os import pytest -import pytest_asyncio - from llama_stack.apis.common.type_system import * # noqa: F403 from llama_stack.apis.datasetio import * # noqa: F403 from llama_stack.distribution.datatypes import * # noqa: F403 @@ -15,35 +14,11 @@ import base64 import mimetypes from pathlib import Path -from llama_stack.providers.tests.resolver import resolve_impls_for_test - # How to run this test: # -# 1. Ensure you have a conda with the right dependencies installed. This is a bit tricky -# since it depends on the provider you are testing. On top of that you need -# `pytest` and `pytest-asyncio` installed. -# -# 2. Copy and modify the provider_config_example.yaml depending on the provider you are testing. -# -# 3. 
Run: -# -# ```bash -# PROVIDER_ID= \ -# PROVIDER_CONFIG=provider_config.yaml \ -# pytest -s llama_stack/providers/tests/datasetio/test_datasetio.py \ -# --tb=short --disable-warnings -# ``` - - -@pytest_asyncio.fixture(scope="session") -async def datasetio_settings(): - impls = await resolve_impls_for_test( - Api.datasetio, - ) - return { - "datasetio_impl": impls[Api.datasetio], - "datasets_impl": impls[Api.datasets], - } +# pytest llama_stack/providers/tests/datasetio/test_datasetio.py +# -m "meta_reference" +# -v -s --tb=short --disable-warnings def data_url_from_file(file_path: str) -> str: @@ -82,8 +57,7 @@ async def register_dataset( dataset = DatasetDefWithProvider( identifier=dataset_id, - provider_id=os.environ.get("DATASETIO_PROVIDER_ID", None) - or os.environ["PROVIDER_ID"], + provider_id="", url=URL( uri=test_url, ), @@ -92,57 +66,47 @@ async def register_dataset( await datasets_impl.register_dataset(dataset) -@pytest.mark.asyncio -async def test_datasets_list(datasetio_settings): - # NOTE: this needs you to ensure that you are starting from a clean state - # but so far we don't have an unregister API unfortunately, so be careful - datasets_impl = datasetio_settings["datasets_impl"] - response = await datasets_impl.list_datasets() - assert isinstance(response, list) - assert len(response) == 0 +class TestDatasetIO: + @pytest.mark.asyncio + async def test_datasets_list(self, datasetio_stack): + # NOTE: this needs you to ensure that you are starting from a clean state + # but so far we don't have an unregister API unfortunately, so be careful + _, datasets_impl = datasetio_stack + response = await datasets_impl.list_datasets() + assert isinstance(response, list) + assert len(response) == 0 + @pytest.mark.asyncio + async def test_register_dataset(self, datasetio_stack): + _, datasets_impl = datasetio_stack + await register_dataset(datasets_impl) + response = await datasets_impl.list_datasets() + assert isinstance(response, list) + assert len(response) == 1 + assert response[0].identifier == "test_dataset" -@pytest.mark.asyncio -async def test_datasets_register(datasetio_settings): - # NOTE: this needs you to ensure that you are starting from a clean state - # but so far we don't have an unregister API unfortunately, so be careful - datasets_impl = datasetio_settings["datasets_impl"] - await register_dataset(datasets_impl) + @pytest.mark.asyncio + async def test_get_rows_paginated(self, datasetio_stack): + datasetio_impl, datasets_impl = datasetio_stack + await register_dataset(datasets_impl) + response = await datasetio_impl.get_rows_paginated( + dataset_id="test_dataset", + rows_in_page=3, + ) + assert isinstance(response.rows, list) + assert len(response.rows) == 3 + assert response.next_page_token == "3" - response = await datasets_impl.list_datasets() - assert isinstance(response, list) - assert len(response) == 1 + provider = datasetio_impl.routing_table.get_provider_impl("test_dataset") + if provider.__provider_spec__.provider_type == "remote": + pytest.skip("remote provider doesn't support get_rows_paginated") - # register same dataset with same id again will fail - await register_dataset(datasets_impl) - response = await datasets_impl.list_datasets() - assert isinstance(response, list) - assert len(response) == 1 - assert response[0].identifier == "test_dataset" - - -@pytest.mark.asyncio -async def test_get_rows_paginated(datasetio_settings): - datasetio_impl = datasetio_settings["datasetio_impl"] - datasets_impl = datasetio_settings["datasets_impl"] - await 
register_dataset(datasets_impl) - - response = await datasetio_impl.get_rows_paginated( - dataset_id="test_dataset", - rows_in_page=3, - ) - - assert isinstance(response.rows, list) - assert len(response.rows) == 3 - assert response.next_page_token == "3" - - # iterate over all rows - response = await datasetio_impl.get_rows_paginated( - dataset_id="test_dataset", - rows_in_page=2, - page_token=response.next_page_token, - ) - - assert isinstance(response.rows, list) - assert len(response.rows) == 2 - assert response.next_page_token == "5" + # iterate over all rows + response = await datasetio_impl.get_rows_paginated( + dataset_id="test_dataset", + rows_in_page=2, + page_token=response.next_page_token, + ) + assert isinstance(response.rows, list) + assert len(response.rows) == 2 + assert response.next_page_token == "5" diff --git a/llama_stack/providers/tests/eval/conftest.py b/llama_stack/providers/tests/eval/conftest.py new file mode 100644 index 000000000..064feb611 --- /dev/null +++ b/llama_stack/providers/tests/eval/conftest.py @@ -0,0 +1,72 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import pytest + +from ..conftest import get_provider_fixture_overrides + +from ..datasetio.fixtures import DATASETIO_FIXTURES +from ..inference.fixtures import INFERENCE_FIXTURES +from ..scoring.fixtures import SCORING_FIXTURES +from .fixtures import EVAL_FIXTURES + +DEFAULT_PROVIDER_COMBINATIONS = [ + pytest.param( + { + "eval": "meta_reference", + "scoring": "meta_reference", + "datasetio": "meta_reference", + "inference": "fireworks", + }, + id="meta_reference_eval_fireworks_inference", + marks=pytest.mark.meta_reference_eval_fireworks_inference, + ), + pytest.param( + { + "eval": "meta_reference", + "scoring": "meta_reference", + "datasetio": "meta_reference", + "inference": "together", + }, + id="meta_reference_eval_together_inference", + marks=pytest.mark.meta_reference_eval_together_inference, + ), +] + + +def pytest_configure(config): + for fixture_name in [ + "meta_reference_eval_fireworks_inference", + "meta_reference_eval_together_inference", + ]: + config.addinivalue_line( + "markers", + f"{fixture_name}: marks tests as {fixture_name} specific", + ) + + +def pytest_addoption(parser): + parser.addoption( + "--inference-model", + action="store", + default="Llama3.2-3B-Instruct", + help="Specify the inference model to use for testing", + ) + + +def pytest_generate_tests(metafunc): + if "eval_stack" in metafunc.fixturenames: + available_fixtures = { + "eval": EVAL_FIXTURES, + "scoring": SCORING_FIXTURES, + "datasetio": DATASETIO_FIXTURES, + "inference": INFERENCE_FIXTURES, + } + combinations = ( + get_provider_fixture_overrides(metafunc.config, available_fixtures) + or DEFAULT_PROVIDER_COMBINATIONS + ) + metafunc.parametrize("eval_stack", combinations, indirect=True) diff --git a/llama_stack/providers/tests/eval/fixtures.py b/llama_stack/providers/tests/eval/fixtures.py new file mode 100644 index 000000000..810239440 --- /dev/null +++ b/llama_stack/providers/tests/eval/fixtures.py @@ -0,0 +1,55 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +import pytest +import pytest_asyncio + +from llama_stack.distribution.datatypes import Api, Provider + +from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2 +from ..conftest import ProviderFixture, remote_stack_fixture + + +@pytest.fixture(scope="session") +def eval_remote() -> ProviderFixture: + return remote_stack_fixture() + + +@pytest.fixture(scope="session") +def eval_meta_reference() -> ProviderFixture: + return ProviderFixture( + providers=[ + Provider( + provider_id="meta-reference", + provider_type="meta-reference", + config={}, + ) + ], + ) + + +EVAL_FIXTURES = ["meta_reference", "remote"] + + +@pytest_asyncio.fixture(scope="session") +async def eval_stack(request): + fixture_dict = request.param + + providers = {} + provider_data = {} + for key in ["datasetio", "eval", "scoring", "inference"]: + fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}") + providers[key] = fixture.providers + if fixture.provider_data: + provider_data.update(fixture.provider_data) + + impls = await resolve_impls_for_test_v2( + [Api.eval, Api.datasetio, Api.inference, Api.scoring], + providers, + provider_data, + ) + + return impls diff --git a/llama_stack/providers/tests/eval/provider_config_example.yaml b/llama_stack/providers/tests/eval/provider_config_example.yaml deleted file mode 100644 index 38f7512f1..000000000 --- a/llama_stack/providers/tests/eval/provider_config_example.yaml +++ /dev/null @@ -1,22 +0,0 @@ -providers: - datasetio: - - provider_id: test-meta - provider_type: meta-reference - config: {} - scoring: - - provider_id: test-meta - provider_type: meta-reference - config: {} - eval: - - provider_id: test-meta - provider_type: meta-reference - config: {} - inference: - - provider_id: test-tgi - provider_type: remote::tgi - config: - url: http://127.0.0.1:5009 - - provider_id: test-tgi-2 - provider_type: remote::tgi - config: - url: http://127.0.0.1:5010 diff --git a/llama_stack/providers/tests/eval/test_eval.py b/llama_stack/providers/tests/eval/test_eval.py index 667be1bd5..a55a754c5 100644 --- a/llama_stack/providers/tests/eval/test_eval.py +++ b/llama_stack/providers/tests/eval/test_eval.py @@ -3,81 +3,124 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import pytest -import pytest_asyncio -from llama_stack.apis.common.type_system import * # noqa: F403 -from llama_stack.apis.datasetio import * # noqa: F403 -from llama_stack.apis.eval.eval import ModelCandidate -from llama_stack.distribution.datatypes import * # noqa: F403 + +import pytest from llama_models.llama3.api import SamplingParams +from llama_stack.apis.eval.eval import ( + AppEvalTaskConfig, + EvalTaskDefWithProvider, + ModelCandidate, +) +from llama_stack.distribution.datatypes import Api from llama_stack.providers.tests.datasetio.test_datasetio import register_dataset -from llama_stack.providers.tests.resolver import resolve_impls_for_test + # How to run this test: # -# 1. Ensure you have a conda with the right dependencies installed. This is a bit tricky -# since it depends on the provider you are testing. On top of that you need -# `pytest` and `pytest-asyncio` installed. -# -# 2. Copy and modify the provider_config_example.yaml depending on the provider you are testing. -# -# 3. 
Run: -# -# ```bash -# PROVIDER_ID= \ -# PROVIDER_CONFIG=provider_config.yaml \ -# pytest -s llama_stack/providers/tests/eval/test_eval.py \ -# --tb=short --disable-warnings -# ``` +# pytest llama_stack/providers/tests/eval/test_eval.py +# -m "meta_reference" +# -v -s --tb=short --disable-warnings -@pytest_asyncio.fixture(scope="session") -async def eval_settings(): - impls = await resolve_impls_for_test( - Api.eval, deps=[Api.datasetio, Api.scoring, Api.inference] - ) - return { - "eval_impl": impls[Api.eval], - "scoring_impl": impls[Api.scoring], - "datasets_impl": impls[Api.datasets], - } +class Testeval: + @pytest.mark.asyncio + async def test_eval_tasks_list(self, eval_stack): + # NOTE: this needs you to ensure that you are starting from a clean state + # but so far we don't have an unregister API unfortunately, so be careful + eval_tasks_impl = eval_stack[Api.eval_tasks] + response = await eval_tasks_impl.list_eval_tasks() + assert isinstance(response, list) + assert len(response) == 0 + @pytest.mark.asyncio + async def test_eval_evaluate_rows(self, eval_stack): + eval_impl, eval_tasks_impl, datasetio_impl, datasets_impl = ( + eval_stack[Api.eval], + eval_stack[Api.eval_tasks], + eval_stack[Api.datasetio], + eval_stack[Api.datasets], + ) + await register_dataset( + datasets_impl, for_generation=True, dataset_id="test_dataset_for_eval" + ) + response = await datasets_impl.list_datasets() + assert len(response) == 1 + rows = await datasetio_impl.get_rows_paginated( + dataset_id="test_dataset_for_eval", + rows_in_page=3, + ) + assert len(rows.rows) == 3 -@pytest.mark.asyncio -async def test_eval(eval_settings): - datasets_impl = eval_settings["datasets_impl"] - await register_dataset( - datasets_impl, - for_generation=True, - dataset_id="test_dataset_for_eval", - ) - - response = await datasets_impl.list_datasets() - assert len(response) == 1 - - eval_impl = eval_settings["eval_impl"] - response = await eval_impl.evaluate_batch( - dataset_id=response[0].identifier, - candidate=ModelCandidate( - model="Llama3.2-1B-Instruct", - sampling_params=SamplingParams(), - ), - scoring_functions=[ - "meta-reference::subset_of", + scoring_functions = [ "meta-reference::llm_as_judge_8b_correctness", - ], - ) - assert response.job_id == "0" - job_status = await eval_impl.job_status(response.job_id) + "meta-reference::equality", + ] + task_id = "meta-reference::app_eval" + task_def = EvalTaskDefWithProvider( + identifier=task_id, + dataset_id="test_dataset_for_eval", + scoring_functions=scoring_functions, + provider_id="meta-reference", + ) + await eval_tasks_impl.register_eval_task(task_def) - assert job_status and job_status.value == "completed" + response = await eval_impl.evaluate_rows( + task_id=task_id, + input_rows=rows.rows, + scoring_functions=scoring_functions, + task_config=AppEvalTaskConfig( + eval_candidate=ModelCandidate( + model="Llama3.2-3B-Instruct", + sampling_params=SamplingParams(), + ), + ), + ) + assert len(response.generations) == 3 + assert "meta-reference::llm_as_judge_8b_correctness" in response.scores + assert "meta-reference::equality" in response.scores - eval_response = await eval_impl.job_result(response.job_id) + @pytest.mark.asyncio + async def test_eval_run_eval(self, eval_stack): + eval_impl, eval_tasks_impl, datasets_impl = ( + eval_stack[Api.eval], + eval_stack[Api.eval_tasks], + eval_stack[Api.datasets], + ) + await register_dataset( + datasets_impl, for_generation=True, dataset_id="test_dataset_for_eval" + ) - assert eval_response is not None - assert 
len(eval_response.generations) == 5 - assert "meta-reference::subset_of" in eval_response.scores - assert "meta-reference::llm_as_judge_8b_correctness" in eval_response.scores + scoring_functions = [ + "meta-reference::llm_as_judge_8b_correctness", + "meta-reference::subset_of", + ] + + task_id = "meta-reference::app_eval-2" + task_def = EvalTaskDefWithProvider( + identifier=task_id, + dataset_id="test_dataset_for_eval", + scoring_functions=scoring_functions, + provider_id="meta-reference", + ) + await eval_tasks_impl.register_eval_task(task_def) + response = await eval_impl.run_eval( + task_id=task_id, + task_config=AppEvalTaskConfig( + eval_candidate=ModelCandidate( + model="Llama3.2-3B-Instruct", + sampling_params=SamplingParams(), + ), + ), + ) + assert response.job_id == "0" + job_status = await eval_impl.job_status(task_id, response.job_id) + assert job_status and job_status.value == "completed" + eval_response = await eval_impl.job_result(task_id, response.job_id) + + assert eval_response is not None + assert len(eval_response.generations) == 5 + assert "meta-reference::subset_of" in eval_response.scores + assert "meta-reference::llm_as_judge_8b_correctness" in eval_response.scores diff --git a/llama_stack/providers/tests/inference/fixtures.py b/llama_stack/providers/tests/inference/fixtures.py index 5b047549b..1698d7584 100644 --- a/llama_stack/providers/tests/inference/fixtures.py +++ b/llama_stack/providers/tests/inference/fixtures.py @@ -64,6 +64,7 @@ def inference_ollama(inference_model) -> ProviderFixture: inference_model = ( [inference_model] if isinstance(inference_model, str) else inference_model ) + print("!!!", inference_model) if "Llama3.1-8B-Instruct" in inference_model: pytest.skip("Ollama only supports Llama3.2-3B-Instruct for testing") diff --git a/llama_stack/providers/tests/scoring/conftest.py b/llama_stack/providers/tests/scoring/conftest.py new file mode 100644 index 000000000..ee578f9b3 --- /dev/null +++ b/llama_stack/providers/tests/scoring/conftest.py @@ -0,0 +1,68 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +import pytest + +from ..conftest import get_provider_fixture_overrides + +from ..datasetio.fixtures import DATASETIO_FIXTURES +from ..inference.fixtures import INFERENCE_FIXTURES +from .fixtures import SCORING_FIXTURES + +DEFAULT_PROVIDER_COMBINATIONS = [ + pytest.param( + { + "scoring": "meta_reference", + "datasetio": "meta_reference", + "inference": "fireworks", + }, + id="meta_reference_scoring_fireworks_inference", + marks=pytest.mark.meta_reference_scoring_fireworks_inference, + ), + pytest.param( + { + "scoring": "meta_reference", + "datasetio": "meta_reference", + "inference": "together", + }, + id="meta_reference_scoring_together_inference", + marks=pytest.mark.meta_reference_scoring_together_inference, + ), +] + + +def pytest_configure(config): + for fixture_name in [ + "meta_reference_scoring_fireworks_inference", + "meta_reference_scoring_together_inference", + ]: + config.addinivalue_line( + "markers", + f"{fixture_name}: marks tests as {fixture_name} specific", + ) + + +def pytest_addoption(parser): + parser.addoption( + "--inference-model", + action="store", + default="Llama3.2-3B-Instruct", + help="Specify the inference model to use for testing", + ) + + +def pytest_generate_tests(metafunc): + if "scoring_stack" in metafunc.fixturenames: + available_fixtures = { + "scoring": SCORING_FIXTURES, + "datasetio": DATASETIO_FIXTURES, + "inference": INFERENCE_FIXTURES, + } + combinations = ( + get_provider_fixture_overrides(metafunc.config, available_fixtures) + or DEFAULT_PROVIDER_COMBINATIONS + ) + metafunc.parametrize("scoring_stack", combinations, indirect=True) diff --git a/llama_stack/providers/tests/scoring/fixtures.py b/llama_stack/providers/tests/scoring/fixtures.py new file mode 100644 index 000000000..925f98779 --- /dev/null +++ b/llama_stack/providers/tests/scoring/fixtures.py @@ -0,0 +1,60 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +import pytest +import pytest_asyncio + +from llama_stack.distribution.datatypes import Api, Provider + +from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2 +from ..conftest import ProviderFixture, remote_stack_fixture + + +@pytest.fixture(scope="session") +def scoring_remote() -> ProviderFixture: + return remote_stack_fixture() + + +@pytest.fixture(scope="session") +def scoring_meta_reference() -> ProviderFixture: + return ProviderFixture( + providers=[ + Provider( + provider_id="meta-reference", + provider_type="meta-reference", + config={}, + ) + ], + ) + + +SCORING_FIXTURES = ["meta_reference", "remote"] + + +@pytest_asyncio.fixture(scope="session") +async def scoring_stack(request): + fixture_dict = request.param + + providers = {} + provider_data = {} + for key in ["datasetio", "scoring", "inference"]: + fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}") + providers[key] = fixture.providers + if fixture.provider_data: + provider_data.update(fixture.provider_data) + + impls = await resolve_impls_for_test_v2( + [Api.scoring, Api.datasetio, Api.inference], + providers, + provider_data, + ) + + return ( + impls[Api.scoring], + impls[Api.scoring_functions], + impls[Api.datasetio], + impls[Api.datasets], + ) diff --git a/llama_stack/providers/tests/scoring/provider_config_example.yaml b/llama_stack/providers/tests/scoring/provider_config_example.yaml deleted file mode 100644 index 6a9c0d842..000000000 --- a/llama_stack/providers/tests/scoring/provider_config_example.yaml +++ /dev/null @@ -1,17 +0,0 @@ -providers: - datasetio: - - provider_id: test-meta - provider_type: meta-reference - config: {} - scoring: - - provider_id: test-meta - provider_type: meta-reference - config: {} - - provider_id: test-braintrust - provider_type: braintrust - config: {} - inference: - - provider_id: tgi0 - provider_type: remote::tgi - config: - url: http://127.0.0.1:5009 diff --git a/llama_stack/providers/tests/scoring/test_scoring.py b/llama_stack/providers/tests/scoring/test_scoring.py index b9b920739..3c1b6554f 100644 --- a/llama_stack/providers/tests/scoring/test_scoring.py +++ b/llama_stack/providers/tests/scoring/test_scoring.py @@ -3,150 +3,109 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import pytest -import pytest_asyncio -from llama_stack.apis.common.type_system import * # noqa: F403 -from llama_stack.apis.datasetio import * # noqa: F403 -from llama_stack.distribution.datatypes import * # noqa: F403 + +import pytest + +from llama_stack.apis.scoring_functions import * # noqa: F403 from llama_stack.providers.tests.datasetio.test_datasetio import register_dataset -from llama_stack.providers.tests.resolver import resolve_impls_for_test # How to run this test: # -# 1. Ensure you have a conda with the right dependencies installed. This is a bit tricky -# since it depends on the provider you are testing. On top of that you need -# `pytest` and `pytest-asyncio` installed. -# -# 2. Copy and modify the provider_config_example.yaml depending on the provider you are testing. -# -# 3. 
Run: -# -# ```bash -# PROVIDER_ID= \ -# PROVIDER_CONFIG=provider_config.yaml \ -# pytest -s llama_stack/providers/tests/scoring/test_scoring.py \ -# --tb=short --disable-warnings -# ``` +# pytest llama_stack/providers/tests/scoring/test_scoring.py +# -m "meta_reference" +# -v -s --tb=short --disable-warnings -@pytest_asyncio.fixture(scope="session") -async def scoring_settings(): - impls = await resolve_impls_for_test( - Api.scoring, deps=[Api.datasetio, Api.inference] - ) - return { - "scoring_impl": impls[Api.scoring], - "scoring_functions_impl": impls[Api.scoring_functions], - "datasets_impl": impls[Api.datasets], - } +class TestScoring: + @pytest.mark.asyncio + async def test_scoring_functions_list(self, scoring_stack): + # NOTE: this needs you to ensure that you are starting from a clean state + # but so far we don't have an unregister API unfortunately, so be careful + _, scoring_functions_impl, _, _ = scoring_stack + response = await scoring_functions_impl.list_scoring_functions() + assert isinstance(response, list) + assert len(response) > 0 - -@pytest_asyncio.fixture(scope="session") -async def provider_scoring_functions(): - return { - "meta-reference": { - "meta-reference::equality", - "meta-reference::subset_of", - "meta-reference::llm_as_judge_8b_correctness", - }, - "braintrust": { - "braintrust::factuality", - "braintrust::answer-correctness", - }, - } - - -@pytest.mark.asyncio -async def test_scoring_functions_list(scoring_settings, provider_scoring_functions): - scoring_impl = scoring_settings["scoring_impl"] - scoring_functions_impl = scoring_settings["scoring_functions_impl"] - scoring_functions = await scoring_functions_impl.list_scoring_functions() - assert isinstance(scoring_functions, list) - assert len(scoring_functions) > 0 - function_ids = [f.identifier for f in scoring_functions] - # get current provider_type we're testing - provider = scoring_impl.routing_table.get_provider_impl(function_ids[0]) - provider_type = provider.__provider_spec__.provider_type - - for x in provider_scoring_functions[provider_type]: - assert x in function_ids - - -@pytest.mark.asyncio -async def test_scoring_functions_register(scoring_settings): - scoring_impl = scoring_settings["scoring_impl"] - scoring_functions_impl = scoring_settings["scoring_functions_impl"] - datasets_impl = scoring_settings["datasets_impl"] - - # get current provider_type we're testing - scoring_functions = await scoring_functions_impl.list_scoring_functions() - function_ids = [f.identifier for f in scoring_functions] - provider = scoring_impl.routing_table.get_provider_impl(function_ids[0]) - provider_type = provider.__provider_spec__.provider_type - if provider_type not in ("meta-reference"): - pytest.skip( - "Other scoring providers don't support registering scoring functions." + @pytest.mark.asyncio + async def test_scoring_score(self, scoring_stack): + scoring_impl, scoring_functions_impl, datasetio_impl, datasets_impl = ( + scoring_stack ) + await register_dataset(datasets_impl) + response = await datasets_impl.list_datasets() + assert len(response) == 1 - test_prompt = """Output a number between 0 to 10. 
Your answer must match the format \n Number: <number>"""
-    # register the scoring function
-    await scoring_functions_impl.register_scoring_function(
-        ScoringFnDefWithProvider(
-            identifier="meta-reference::llm_as_judge_8b_random",
-            description="Llm As Judge Scoring Function",
-            parameters=[],
-            return_type=NumberType(),
-            context=LLMAsJudgeContext(
-                prompt_template=test_prompt,
-                judge_model="Llama3.1-8B-Instruct",
-                judge_score_regex=[r"Number: (\d+)"],
-            ),
-            provider_id="test-meta",
+        # scoring individual rows
+        rows = await datasetio_impl.get_rows_paginated(
+            dataset_id="test_dataset",
+            rows_in_page=3,
         )
-    )
+        assert len(rows.rows) == 3
-    scoring_functions = await scoring_functions_impl.list_scoring_functions()
-    assert isinstance(scoring_functions, list)
-    assert len(scoring_functions) > 0
-    function_ids = [f.identifier for f in scoring_functions]
-    assert "meta-reference::llm_as_judge_8b_random" in function_ids
+        scoring_functions = {
+            "meta-reference::llm_as_judge_8b_correctness": None,
+            "meta-reference::equality": None,
+        }
+        response = await scoring_impl.score(
+            input_rows=rows.rows,
+            scoring_functions=scoring_functions,
+        )
+        assert len(response.results) == len(scoring_functions)
+        for x in scoring_functions:
+            assert x in response.results
+            assert len(response.results[x].score_rows) == len(rows.rows)
-    # test score using newly registered scoring function
-    await register_dataset(datasets_impl)
-    response = await datasets_impl.list_datasets()
-    assert len(response) == 1
-    response = await scoring_impl.score_batch(
-        dataset_id=response[0].identifier,
-        scoring_functions=[
-            "meta-reference::llm_as_judge_8b_random",
-        ],
-    )
-    assert "meta-reference::llm_as_judge_8b_random" in response.results
+        # score batch
+        response = await scoring_impl.score_batch(
+            dataset_id="test_dataset",
+            scoring_functions=scoring_functions,
+        )
+        assert len(response.results) == len(scoring_functions)
+        for x in scoring_functions:
+            assert x in response.results
+            assert len(response.results[x].score_rows) == 5
+    @pytest.mark.asyncio
+    async def test_scoring_score_with_params(self, scoring_stack):
+        scoring_impl, scoring_functions_impl, datasetio_impl, datasets_impl = (
+            scoring_stack
+        )
+        await register_dataset(datasets_impl)
+        response = await datasets_impl.list_datasets()
+        assert len(response) == 1
-@pytest.mark.asyncio
-async def test_scoring_score(scoring_settings, provider_scoring_functions):
-    scoring_impl = scoring_settings["scoring_impl"]
-    datasets_impl = scoring_settings["datasets_impl"]
-    scoring_functions_impl = scoring_settings["scoring_functions_impl"]
-    await register_dataset(datasets_impl)
+        # scoring individual rows
+        rows = await datasetio_impl.get_rows_paginated(
+            dataset_id="test_dataset",
+            rows_in_page=3,
+        )
+        assert len(rows.rows) == 3
-    response = await datasets_impl.list_datasets()
-    assert len(response) == 1
+        scoring_functions = {
+            "meta-reference::llm_as_judge_8b_correctness": LLMAsJudgeScoringFnParams(
+                judge_model="Llama3.1-405B-Instruct",
+                prompt_template="Output a number response in the following format: Score: <number>, where <number> is the number between 0 and 9.",
+                judge_score_regexes=[r"Score: (\d+)"],
+            )
+        }
-    # get current provider_type we're testing
-    scoring_functions = await scoring_functions_impl.list_scoring_functions()
-    function_ids = [f.identifier for f in scoring_functions]
-    provider = scoring_impl.routing_table.get_provider_impl(function_ids[0])
-    provider_type = provider.__provider_spec__.provider_type
+        response = await scoring_impl.score(
+            input_rows=rows.rows,
+            scoring_functions=scoring_functions,
+        )
+        assert len(response.results) == len(scoring_functions)
+        for x in scoring_functions:
+            assert x in response.results
+            assert len(response.results[x].score_rows) == len(rows.rows)
-    response = await scoring_impl.score_batch(
-        dataset_id=response[0].identifier,
-        scoring_functions=list(provider_scoring_functions[provider_type]),
-    )
-
-    assert len(response.results) == len(provider_scoring_functions[provider_type])
-    for x in provider_scoring_functions[provider_type]:
-        assert x in response.results
+        # score batch
+        response = await scoring_impl.score_batch(
+            dataset_id="test_dataset",
+            scoring_functions=scoring_functions,
+        )
+        assert len(response.results) == len(scoring_functions)
+        for x in scoring_functions:
+            assert x in response.results
+            assert len(response.results[x].score_rows) == 5
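
A note on the LLMAsJudgeScoringFnParams usage added to test_scoring.py above: the params carry the judge model, a prompt template, and judge_score_regexes, and the scoring provider is expected to pull the numeric score out of the judge's free-form reply by matching those regexes. Below is a minimal sketch of that extraction step using the regex from the test; the sample judge reply and the helper function are illustrative assumptions, not code from this patch.

```python
import re
from typing import Optional

# Regex taken from the test above; the judge reply below is made up for illustration.
JUDGE_SCORE_REGEXES = [r"Score: (\d+)"]
SAMPLE_JUDGE_REPLY = "The generated answer is mostly correct. Score: 7"


def extract_judge_score(reply: str, patterns: list[str]) -> Optional[int]:
    # Return the first captured integer, or None if no pattern matches.
    for pattern in patterns:
        match = re.search(pattern, reply)
        if match:
            return int(match.group(1))
    return None


print(extract_judge_score(SAMPLE_JUDGE_REPLY, JUDGE_SCORE_REGEXES))  # -> 7
```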
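The new conftest.py files parametrize the eval_stack and scoring_stack fixtures over DEFAULT_PROVIDER_COMBINATIONS and register one pytest marker per combination, so a single combination can be selected with -m. The updated header comments show the CLI form; the snippet below is an equivalent programmatic invocation through pytest's own API. The chosen marker is one of those registered in the eval conftest above, and valid credentials for the selected inference provider are assumed to be configured.

```python
import pytest

# Run only the eval tests for the meta-reference eval + together inference combination.
# Mirrors the command documented in test_eval.py; swap the marker to pick another combination.
exit_code = pytest.main([
    "llama_stack/providers/tests/eval/test_eval.py",
    "-m", "meta_reference_eval_together_inference",
    "-v", "-s", "--tb=short", "--disable-warnings",
])
raise SystemExit(exit_code)
```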