From fb2678b134a48df7c5d578e0d9dcfc8619b2c425 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Mon, 4 Nov 2024 19:40:04 -0800 Subject: [PATCH 1/9] Fix shield_type and routing table breakage --- distributions/meta-reference-gpu/run.yaml | 21 +++++++++----- llama_stack/apis/shields/shields.py | 2 +- .../distribution/routers/routing_tables.py | 28 ++++++------------- .../adapters/safety/together/together.py | 2 +- .../meta_reference/codeshield/code_scanner.py | 4 +-- .../impls/meta_reference/safety/safety.py | 8 +++--- 6 files changed, 30 insertions(+), 35 deletions(-) diff --git a/distributions/meta-reference-gpu/run.yaml b/distributions/meta-reference-gpu/run.yaml index 9bf7655f9..ad3187aa1 100644 --- a/distributions/meta-reference-gpu/run.yaml +++ b/distributions/meta-reference-gpu/run.yaml @@ -13,14 +13,22 @@ apis: - safety providers: inference: - - provider_id: meta0 + - provider_id: meta-reference-inference provider_type: meta-reference config: - model: Llama3.1-8B-Instruct + model: Llama3.2-3B-Instruct quantization: null torch_seed: null max_seq_len: 4096 max_batch_size: 1 + - provider_id: meta-reference-safety + provider_type: meta-reference + config: + model: Llama-Guard-3-1B + quantization: null + torch_seed: null + max_seq_len: 2048 + max_batch_size: 1 safety: - provider_id: meta0 provider_type: meta-reference @@ -28,10 +36,9 @@ providers: llama_guard_shield: model: Llama-Guard-3-1B excluded_categories: [] - disable_input_check: false - disable_output_check: false - prompt_guard_shield: - model: Prompt-Guard-86M +# Uncomment to use prompt guard +# prompt_guard_shield: +# model: Prompt-Guard-86M memory: - provider_id: meta0 provider_type: meta-reference @@ -52,7 +59,7 @@ providers: persistence_store: namespace: null type: sqlite - db_path: ~/.llama/runtime/kvstore.db + db_path: ~/.llama/runtime/agents_store.db telemetry: - provider_id: meta0 provider_type: meta-reference diff --git a/llama_stack/apis/shields/shields.py b/llama_stack/apis/shields/shields.py index 0d1177f5a..7c8e3939a 100644 --- a/llama_stack/apis/shields/shields.py +++ b/llama_stack/apis/shields/shields.py @@ -23,7 +23,7 @@ class ShieldDef(BaseModel): identifier: str = Field( description="A unique identifier for the shield type", ) - type: str = Field( + shield_type: str = Field( description="The type of shield this is; the value is one of the ShieldType enum" ) params: Dict[str, Any] = Field( diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index fcf3451c1..c184557c6 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -178,13 +178,13 @@ class CommonRoutingTableImpl(RoutingTable): await register_object_with_provider(obj, p) await self.dist_registry.register(obj) + async def get_all(self) -> List[RoutableObjectWithProvider]: + return await self.dist_registry.get_all() + class ModelsRoutingTable(CommonRoutingTableImpl, Models): async def list_models(self) -> List[ModelDefWithProvider]: - objects = [] - for objs in self.registry.values(): - objects.extend(objs) - return objects + return await self.get_all() async def get_model(self, identifier: str) -> Optional[ModelDefWithProvider]: return self.get_object_by_identifier(identifier) @@ -195,10 +195,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models): class ShieldsRoutingTable(CommonRoutingTableImpl, Shields): async def list_shields(self) -> List[ShieldDef]: - objects = [] - for objs in self.registry.values(): - objects.extend(objs) - return objects + return await self.get_all() async def get_shield(self, shield_type: str) -> Optional[ShieldDefWithProvider]: return self.get_object_by_identifier(shield_type) @@ -209,10 +206,7 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields): class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks): async def list_memory_banks(self) -> List[MemoryBankDefWithProvider]: - objects = [] - for objs in self.registry.values(): - objects.extend(objs) - return objects + return await self.get_all() async def get_memory_bank( self, identifier: str @@ -227,10 +221,7 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks): class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets): async def list_datasets(self) -> List[DatasetDefWithProvider]: - objects = [] - for objs in self.registry.values(): - objects.extend(objs) - return objects + return await self.get_all() async def get_dataset( self, dataset_identifier: str @@ -243,10 +234,7 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets): class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, Scoring): async def list_scoring_functions(self) -> List[ScoringFnDefWithProvider]: - objects = [] - for objs in self.registry.values(): - objects.extend(objs) - return objects + return await self.get_all() async def get_scoring_function( self, name: str diff --git a/llama_stack/providers/adapters/safety/together/together.py b/llama_stack/providers/adapters/safety/together/together.py index c7e9630eb..da45ed5b8 100644 --- a/llama_stack/providers/adapters/safety/together/together.py +++ b/llama_stack/providers/adapters/safety/together/together.py @@ -37,7 +37,7 @@ class TogetherSafetyImpl(Safety, NeedsRequestProviderData, ShieldsProtocolPrivat return [ ShieldDef( identifier=ShieldType.llama_guard.value, - type=ShieldType.llama_guard.value, + shield_type=ShieldType.llama_guard.value, params={}, ) ] diff --git a/llama_stack/providers/impls/meta_reference/codeshield/code_scanner.py b/llama_stack/providers/impls/meta_reference/codeshield/code_scanner.py index 37ea96270..fc6efd71b 100644 --- a/llama_stack/providers/impls/meta_reference/codeshield/code_scanner.py +++ b/llama_stack/providers/impls/meta_reference/codeshield/code_scanner.py @@ -25,8 +25,8 @@ class MetaReferenceCodeScannerSafetyImpl(Safety): pass async def register_shield(self, shield: ShieldDef) -> None: - if shield.type != ShieldType.code_scanner.value: - raise ValueError(f"Unsupported safety shield type: {shield.type}") + if shield.shield_type != ShieldType.code_scanner.value: + raise ValueError(f"Unsupported safety shield type: {shield.shield_type}") async def run_shield( self, diff --git a/llama_stack/providers/impls/meta_reference/safety/safety.py b/llama_stack/providers/impls/meta_reference/safety/safety.py index de438ad29..28c78b65c 100644 --- a/llama_stack/providers/impls/meta_reference/safety/safety.py +++ b/llama_stack/providers/impls/meta_reference/safety/safety.py @@ -49,7 +49,7 @@ class MetaReferenceSafetyImpl(Safety, ShieldsProtocolPrivate): return [ ShieldDef( identifier=shield_type, - type=shield_type, + shield_type=shield_type, params={}, ) for shield_type in self.available_shields @@ -92,14 +92,14 @@ class MetaReferenceSafetyImpl(Safety, ShieldsProtocolPrivate): return RunShieldResponse(violation=violation) def get_shield_impl(self, shield: ShieldDef) -> ShieldBase: - if shield.type == ShieldType.llama_guard.value: + if shield.shield_type == ShieldType.llama_guard.value: cfg = self.config.llama_guard_shield return LlamaGuardShield( model=cfg.model, inference_api=self.inference_api, excluded_categories=cfg.excluded_categories, ) - elif shield.type == ShieldType.prompt_guard.value: + elif shield.shield_type == ShieldType.prompt_guard.value: model_dir = model_local_dir(PROMPT_GUARD_MODEL) subtype = shield.params.get("prompt_guard_type", "injection") if subtype == "injection": @@ -109,4 +109,4 @@ class MetaReferenceSafetyImpl(Safety, ShieldsProtocolPrivate): else: raise ValueError(f"Unknown prompt guard type: {subtype}") else: - raise ValueError(f"Unknown shield type: {shield.type}") + raise ValueError(f"Unknown shield type: {shield.shield_type}") From 0763a0b85fa77ee8798635fe450435f67dfc42a0 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Mon, 4 Nov 2024 20:06:01 -0800 Subject: [PATCH 2/9] Fix for the fix! --- .../distribution/routers/routing_tables.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index c184557c6..17bda0e70 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -178,13 +178,14 @@ class CommonRoutingTableImpl(RoutingTable): await register_object_with_provider(obj, p) await self.dist_registry.register(obj) - async def get_all(self) -> List[RoutableObjectWithProvider]: - return await self.dist_registry.get_all() + async def get_all_with_type(self, type: str) -> List[RoutableObjectWithProvider]: + objs = await self.dist_registry.get_all() + return [obj for obj in objs if obj.type == type] class ModelsRoutingTable(CommonRoutingTableImpl, Models): async def list_models(self) -> List[ModelDefWithProvider]: - return await self.get_all() + return await self.get_all_with_type("model") async def get_model(self, identifier: str) -> Optional[ModelDefWithProvider]: return self.get_object_by_identifier(identifier) @@ -195,7 +196,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models): class ShieldsRoutingTable(CommonRoutingTableImpl, Shields): async def list_shields(self) -> List[ShieldDef]: - return await self.get_all() + return await self.get_all_with_type("shield") async def get_shield(self, shield_type: str) -> Optional[ShieldDefWithProvider]: return self.get_object_by_identifier(shield_type) @@ -206,7 +207,7 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields): class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks): async def list_memory_banks(self) -> List[MemoryBankDefWithProvider]: - return await self.get_all() + return await self.get_all_with_type("memory_bank") async def get_memory_bank( self, identifier: str @@ -221,7 +222,7 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks): class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets): async def list_datasets(self) -> List[DatasetDefWithProvider]: - return await self.get_all() + return await self.get_all_with_type("dataset") async def get_dataset( self, dataset_identifier: str @@ -234,7 +235,7 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets): class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, Scoring): async def list_scoring_functions(self) -> List[ScoringFnDefWithProvider]: - return await self.get_all() + return await self.get_all_with_type("scoring_function") async def get_scoring_function( self, name: str From 7cf4c905f3b4dc5c7986b41b16fbcf7fe95e15c0 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Mon, 4 Nov 2024 19:57:40 -0800 Subject: [PATCH 3/9] add support for remote providers in tests --- llama_stack/distribution/client.py | 6 ++++-- llama_stack/providers/tests/agents/conftest.py | 12 +++++++++++- llama_stack/providers/tests/agents/fixtures.py | 9 +++++++-- .../providers/tests/agents/test_agents.py | 1 - llama_stack/providers/tests/conftest.py | 18 ++++++++++++++++++ .../providers/tests/inference/fixtures.py | 9 +++++++-- llama_stack/providers/tests/memory/fixtures.py | 9 +++++++-- llama_stack/providers/tests/resolver.py | 9 ++++++++- llama_stack/providers/tests/safety/conftest.py | 10 +++++++++- llama_stack/providers/tests/safety/fixtures.py | 9 +++++++-- .../providers/tests/safety/test_safety.py | 2 +- 11 files changed, 79 insertions(+), 15 deletions(-) diff --git a/llama_stack/distribution/client.py b/llama_stack/distribution/client.py index acc871f01..613c90bd6 100644 --- a/llama_stack/distribution/client.py +++ b/llama_stack/distribution/client.py @@ -83,6 +83,7 @@ def create_api_client_class(protocol, additional_protocol) -> Type: j = response.json() if j is None: return None + # print(f"({protocol.__name__}) Returning {j}, type {return_type}") return parse_obj_as(return_type, j) async def _call_streaming(self, method_name: str, *args, **kwargs) -> Any: @@ -102,14 +103,15 @@ def create_api_client_class(protocol, additional_protocol) -> Type: if line.startswith("data:"): data = line[len("data: ") :] try: + data = json.loads(data) if "error" in data: cprint(data, "red") continue - yield parse_obj_as(return_type, json.loads(data)) + yield parse_obj_as(return_type, data) except Exception as e: - print(data) print(f"Error with parsing or validation: {e}") + print(data) def httpx_request_params(self, method_name: str, *args, **kwargs) -> dict: webmethod, sig = self.routes[method_name] diff --git a/llama_stack/providers/tests/agents/conftest.py b/llama_stack/providers/tests/agents/conftest.py index 332efeed8..7b16242cf 100644 --- a/llama_stack/providers/tests/agents/conftest.py +++ b/llama_stack/providers/tests/agents/conftest.py @@ -46,11 +46,21 @@ DEFAULT_PROVIDER_COMBINATIONS = [ id="together", marks=pytest.mark.together, ), + pytest.param( + { + "inference": "remote", + "safety": "remote", + "memory": "remote", + "agents": "remote", + }, + id="remote", + marks=pytest.mark.remote, + ), ] def pytest_configure(config): - for mark in ["meta_reference", "ollama", "together"]: + for mark in ["meta_reference", "ollama", "together", "remote"]: config.addinivalue_line( "markers", f"{mark}: marks tests as {mark} specific", diff --git a/llama_stack/providers/tests/agents/fixtures.py b/llama_stack/providers/tests/agents/fixtures.py index c667712a7..153ade0da 100644 --- a/llama_stack/providers/tests/agents/fixtures.py +++ b/llama_stack/providers/tests/agents/fixtures.py @@ -18,7 +18,12 @@ from llama_stack.providers.impls.meta_reference.agents import ( from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2 from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig -from ..conftest import ProviderFixture +from ..conftest import ProviderFixture, remote_stack_fixture + + +@pytest.fixture(scope="session") +def agents_remote() -> ProviderFixture: + return remote_stack_fixture() @pytest.fixture(scope="session") @@ -40,7 +45,7 @@ def agents_meta_reference() -> ProviderFixture: ) -AGENTS_FIXTURES = ["meta_reference"] +AGENTS_FIXTURES = ["meta_reference", "remote"] @pytest_asyncio.fixture(scope="session") diff --git a/llama_stack/providers/tests/agents/test_agents.py b/llama_stack/providers/tests/agents/test_agents.py index 54c10a42d..5b1fe202a 100644 --- a/llama_stack/providers/tests/agents/test_agents.py +++ b/llama_stack/providers/tests/agents/test_agents.py @@ -109,7 +109,6 @@ class TestAgents: turn_response = [ chunk async for chunk in await agents_impl.create_agent_turn(**turn_request) ] - assert len(turn_response) > 0 check_event_types(turn_response) diff --git a/llama_stack/providers/tests/conftest.py b/llama_stack/providers/tests/conftest.py index 9fdf94582..11b0dcb45 100644 --- a/llama_stack/providers/tests/conftest.py +++ b/llama_stack/providers/tests/conftest.py @@ -14,6 +14,9 @@ from pydantic import BaseModel from termcolor import colored from llama_stack.distribution.datatypes import Provider +from llama_stack.providers.datatypes import RemoteProviderConfig + +from .env import get_env_or_fail class ProviderFixture(BaseModel): @@ -21,6 +24,21 @@ class ProviderFixture(BaseModel): provider_data: Optional[Dict[str, Any]] = None +def remote_stack_fixture() -> ProviderFixture: + return ProviderFixture( + providers=[ + Provider( + provider_id="remote", + provider_type="remote", + config=RemoteProviderConfig( + host=get_env_or_fail("REMOTE_STACK_HOST"), + port=int(get_env_or_fail("REMOTE_STACK_PORT")), + ).model_dump(), + ) + ], + ) + + def pytest_configure(config): config.option.tbstyle = "short" config.option.disable_warnings = True diff --git a/llama_stack/providers/tests/inference/fixtures.py b/llama_stack/providers/tests/inference/fixtures.py index 860eea4b2..896acbad8 100644 --- a/llama_stack/providers/tests/inference/fixtures.py +++ b/llama_stack/providers/tests/inference/fixtures.py @@ -18,7 +18,7 @@ from llama_stack.providers.impls.meta_reference.inference import ( MetaReferenceInferenceConfig, ) from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2 -from ..conftest import ProviderFixture +from ..conftest import ProviderFixture, remote_stack_fixture from ..env import get_env_or_fail @@ -29,6 +29,11 @@ def inference_model(request): return request.config.getoption("--inference-model", None) +@pytest.fixture(scope="session") +def inference_remote() -> ProviderFixture: + return remote_stack_fixture() + + @pytest.fixture(scope="session") def inference_meta_reference(inference_model) -> ProviderFixture: inference_model = ( @@ -104,7 +109,7 @@ def inference_together() -> ProviderFixture: ) -INFERENCE_FIXTURES = ["meta_reference", "ollama", "fireworks", "together"] +INFERENCE_FIXTURES = ["meta_reference", "ollama", "fireworks", "together", "remote"] @pytest_asyncio.fixture(scope="session") diff --git a/llama_stack/providers/tests/memory/fixtures.py b/llama_stack/providers/tests/memory/fixtures.py index 4a6642e85..adeab8476 100644 --- a/llama_stack/providers/tests/memory/fixtures.py +++ b/llama_stack/providers/tests/memory/fixtures.py @@ -15,10 +15,15 @@ from llama_stack.providers.adapters.memory.weaviate import WeaviateConfig from llama_stack.providers.impls.meta_reference.memory import FaissImplConfig from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2 -from ..conftest import ProviderFixture +from ..conftest import ProviderFixture, remote_stack_fixture from ..env import get_env_or_fail +@pytest.fixture(scope="session") +def memory_remote() -> ProviderFixture: + return remote_stack_fixture() + + @pytest.fixture(scope="session") def memory_meta_reference() -> ProviderFixture: return ProviderFixture( @@ -68,7 +73,7 @@ def memory_weaviate() -> ProviderFixture: ) -MEMORY_FIXTURES = ["meta_reference", "pgvector", "weaviate"] +MEMORY_FIXTURES = ["meta_reference", "pgvector", "weaviate", "remote"] @pytest_asyncio.fixture(scope="session") diff --git a/llama_stack/providers/tests/resolver.py b/llama_stack/providers/tests/resolver.py index 2d6805b35..16c2a32af 100644 --- a/llama_stack/providers/tests/resolver.py +++ b/llama_stack/providers/tests/resolver.py @@ -6,6 +6,7 @@ import json import os +import tempfile from datetime import datetime from typing import Any, Dict, List, Optional @@ -16,6 +17,8 @@ from llama_stack.distribution.configure import parse_and_maybe_upgrade_config from llama_stack.distribution.distribution import get_provider_registry from llama_stack.distribution.request_headers import set_request_provider_data from llama_stack.distribution.resolver import resolve_impls +from llama_stack.distribution.store import CachedDiskDistributionRegistry +from llama_stack.providers.utils.kvstore import kvstore_impl, SqliteKVStoreConfig async def resolve_impls_for_test_v2( @@ -30,7 +33,11 @@ async def resolve_impls_for_test_v2( providers=providers, ) run_config = parse_and_maybe_upgrade_config(run_config) - impls = await resolve_impls(run_config, get_provider_registry()) + + sqlite_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db") + dist_kvstore = await kvstore_impl(SqliteKVStoreConfig(db_path=sqlite_file.name)) + dist_registry = CachedDiskDistributionRegistry(dist_kvstore) + impls = await resolve_impls(run_config, get_provider_registry(), dist_registry) if provider_data: set_request_provider_data( diff --git a/llama_stack/providers/tests/safety/conftest.py b/llama_stack/providers/tests/safety/conftest.py index c5424f8db..fb47b290d 100644 --- a/llama_stack/providers/tests/safety/conftest.py +++ b/llama_stack/providers/tests/safety/conftest.py @@ -37,11 +37,19 @@ DEFAULT_PROVIDER_COMBINATIONS = [ id="together", marks=pytest.mark.together, ), + pytest.param( + { + "inference": "remote", + "safety": "remote", + }, + id="remote", + marks=pytest.mark.remote, + ), ] def pytest_configure(config): - for mark in ["meta_reference", "ollama", "together"]: + for mark in ["meta_reference", "ollama", "together", "remote"]: config.addinivalue_line( "markers", f"{mark}: marks tests as {mark} specific", diff --git a/llama_stack/providers/tests/safety/fixtures.py b/llama_stack/providers/tests/safety/fixtures.py index 463c53d2c..74f8ef503 100644 --- a/llama_stack/providers/tests/safety/fixtures.py +++ b/llama_stack/providers/tests/safety/fixtures.py @@ -16,10 +16,15 @@ from llama_stack.providers.impls.meta_reference.safety import ( from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2 -from ..conftest import ProviderFixture +from ..conftest import ProviderFixture, remote_stack_fixture from ..env import get_env_or_fail +@pytest.fixture(scope="session") +def safety_remote() -> ProviderFixture: + return remote_stack_fixture() + + @pytest.fixture(scope="session") def safety_model(request): if hasattr(request, "param"): @@ -60,7 +65,7 @@ def safety_together() -> ProviderFixture: ) -SAFETY_FIXTURES = ["meta_reference", "together"] +SAFETY_FIXTURES = ["meta_reference", "together", "remote"] @pytest_asyncio.fixture(scope="session") diff --git a/llama_stack/providers/tests/safety/test_safety.py b/llama_stack/providers/tests/safety/test_safety.py index ddf472737..9a629e85c 100644 --- a/llama_stack/providers/tests/safety/test_safety.py +++ b/llama_stack/providers/tests/safety/test_safety.py @@ -27,7 +27,7 @@ class TestSafety: for shield in response: assert isinstance(shield, ShieldDefWithProvider) - assert shield.type in [v.value for v in ShieldType] + assert shield.shield_type in [v.value for v in ShieldType] @pytest.mark.asyncio async def test_run_shield(self, safety_stack): From 9a57a009eeab69924ac2e0861f99052d327d99ba Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Mon, 4 Nov 2024 20:32:47 -0800 Subject: [PATCH 4/9] Need to await for get_object_from_identifier() now --- llama_stack/distribution/routers/routing_tables.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index 17bda0e70..1efd02c89 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -188,7 +188,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models): return await self.get_all_with_type("model") async def get_model(self, identifier: str) -> Optional[ModelDefWithProvider]: - return self.get_object_by_identifier(identifier) + return await self.get_object_by_identifier(identifier) async def register_model(self, model: ModelDefWithProvider) -> None: await self.register_object(model) @@ -199,7 +199,7 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields): return await self.get_all_with_type("shield") async def get_shield(self, shield_type: str) -> Optional[ShieldDefWithProvider]: - return self.get_object_by_identifier(shield_type) + return await self.get_object_by_identifier(shield_type) async def register_shield(self, shield: ShieldDefWithProvider) -> None: await self.register_object(shield) @@ -212,7 +212,7 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks): async def get_memory_bank( self, identifier: str ) -> Optional[MemoryBankDefWithProvider]: - return self.get_object_by_identifier(identifier) + return await self.get_object_by_identifier(identifier) async def register_memory_bank( self, memory_bank: MemoryBankDefWithProvider @@ -227,7 +227,7 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets): async def get_dataset( self, dataset_identifier: str ) -> Optional[DatasetDefWithProvider]: - return self.get_object_by_identifier(dataset_identifier) + return await self.get_object_by_identifier(dataset_identifier) async def register_dataset(self, dataset_def: DatasetDefWithProvider) -> None: await self.register_object(dataset_def) @@ -240,7 +240,7 @@ class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, Scoring): async def get_scoring_function( self, name: str ) -> Optional[ScoringFnDefWithProvider]: - return self.get_object_by_identifier(name) + return await self.get_object_by_identifier(name) async def register_scoring_function( self, function_def: ScoringFnDefWithProvider From a81178f1f590952c356d3803bd7585cb02f0b2e8 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Mon, 4 Nov 2024 20:35:53 -0800 Subject: [PATCH 5/9] The server now depends on SQLite by default --- llama_stack/distribution/build.py | 1 + 1 file changed, 1 insertion(+) diff --git a/llama_stack/distribution/build.py b/llama_stack/distribution/build.py index e3a9d9186..0a989d2e4 100644 --- a/llama_stack/distribution/build.py +++ b/llama_stack/distribution/build.py @@ -25,6 +25,7 @@ from llama_stack.distribution.utils.config_dirs import BUILDS_BASE_DIR # These are the dependencies needed by the distribution server. # `llama-stack` is automatically installed by the installation script. SERVER_DEPENDENCIES = [ + "aiosqlite", "fastapi", "fire", "httpx", From 3ca294c35907f19c366770fce501424228171838 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Mon, 4 Nov 2024 20:38:00 -0800 Subject: [PATCH 6/9] Bump version to 0.0.49 --- requirements.txt | 2 +- setup.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/requirements.txt b/requirements.txt index dfd187191..a95e781b7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,7 +2,7 @@ blobfile fire httpx huggingface-hub -llama-models>=0.0.48 +llama-models>=0.0.49 prompt-toolkit python-dotenv pydantic>=2 diff --git a/setup.py b/setup.py index a0752dd7e..70fbe0074 100644 --- a/setup.py +++ b/setup.py @@ -16,7 +16,7 @@ def read_requirements(): setup( name="llama_stack", - version="0.0.48", + version="0.0.49", author="Meta Llama", author_email="llama-oss@meta.com", description="Llama Stack", From 122793ab9224bd9a520cc3df6628e3f15c6c5c33 Mon Sep 17 00:00:00 2001 From: Steve Grubb Date: Mon, 4 Nov 2024 23:49:35 -0500 Subject: [PATCH 7/9] Correct a traceback in vllm (#366) File "/usr/local/lib/python3.10/site-packages/llama_stack/providers/adapters/inference/vllm/vllm.py", line 136, in _stream_chat_completion async for chunk in process_chat_completion_stream_response( TypeError: process_chat_completion_stream_response() takes 2 positional arguments but 3 were given This corrects the error by deleting the request variable --- llama_stack/providers/adapters/inference/vllm/vllm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llama_stack/providers/adapters/inference/vllm/vllm.py b/llama_stack/providers/adapters/inference/vllm/vllm.py index 4cf55035c..aad2fdc1f 100644 --- a/llama_stack/providers/adapters/inference/vllm/vllm.py +++ b/llama_stack/providers/adapters/inference/vllm/vllm.py @@ -134,7 +134,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate): stream = _to_async_generator() async for chunk in process_chat_completion_stream_response( - request, stream, self.formatter + stream, self.formatter ): yield chunk From f08efc23a6c2547c7a31aaa40ab045d531e680d5 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Mon, 4 Nov 2024 22:06:15 -0800 Subject: [PATCH 8/9] Kill non-integration older tests --- tests/test_bedrock_inference.py | 446 -------------------------------- tests/test_e2e.py | 183 ------------- tests/test_inference.py | 255 ------------------ tests/test_ollama_inference.py | 346 ------------------------- 4 files changed, 1230 deletions(-) delete mode 100644 tests/test_bedrock_inference.py delete mode 100644 tests/test_e2e.py delete mode 100644 tests/test_inference.py delete mode 100644 tests/test_ollama_inference.py diff --git a/tests/test_bedrock_inference.py b/tests/test_bedrock_inference.py deleted file mode 100644 index 54110a144..000000000 --- a/tests/test_bedrock_inference.py +++ /dev/null @@ -1,446 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -import unittest -from unittest import mock - -from llama_models.llama3.api.datatypes import ( - BuiltinTool, - CompletionMessage, - SamplingParams, - SamplingStrategy, - StopReason, - ToolCall, - ToolChoice, - ToolDefinition, - ToolParamDefinition, - ToolResponseMessage, - UserMessage, -) -from llama_stack.apis.inference.inference import ( - ChatCompletionRequest, - ChatCompletionResponseEventType, -) -from llama_stack.providers.adapters.inference.bedrock import get_adapter_impl -from llama_stack.providers.adapters.inference.bedrock.config import BedrockConfig - - -class BedrockInferenceTests(unittest.IsolatedAsyncioTestCase): - - async def asyncSetUp(self): - bedrock_config = BedrockConfig() - - # setup Bedrock - self.api = await get_adapter_impl(bedrock_config, {}) - await self.api.initialize() - - self.custom_tool_defn = ToolDefinition( - tool_name="get_boiling_point", - description="Get the boiling point of a imaginary liquids (eg. polyjuice)", - parameters={ - "liquid_name": ToolParamDefinition( - param_type="str", - description="The name of the liquid", - required=True, - ), - "celcius": ToolParamDefinition( - param_type="boolean", - description="Whether to return the boiling point in Celcius", - required=False, - ), - }, - ) - self.valid_supported_model = "Meta-Llama3.1-8B-Instruct" - - async def asyncTearDown(self): - await self.api.shutdown() - - async def test_text(self): - with mock.patch.object(self.api.client, "converse") as mock_converse: - mock_converse.return_value = { - "ResponseMetadata": { - "RequestId": "8ad04352-cd81-4946-b811-b434e546385d", - "HTTPStatusCode": 200, - "HTTPHeaders": {}, - "RetryAttempts": 0, - }, - "output": { - "message": { - "role": "assistant", - "content": [{"text": "\n\nThe capital of France is Paris."}], - } - }, - "stopReason": "end_turn", - "usage": {"inputTokens": 21, "outputTokens": 9, "totalTokens": 30}, - "metrics": {"latencyMs": 307}, - } - request = ChatCompletionRequest( - model=self.valid_supported_model, - messages=[ - UserMessage( - content="What is the capital of France?", - ), - ], - stream=False, - ) - iterator = self.api.chat_completion( - request.model, - request.messages, - request.sampling_params, - request.tools, - request.tool_choice, - request.tool_prompt_format, - request.stream, - request.logprobs, - ) - async for r in iterator: - response = r - print(response.completion_message.content) - self.assertTrue("Paris" in response.completion_message.content[0]) - self.assertEqual( - response.completion_message.stop_reason, StopReason.end_of_turn - ) - - async def test_tool_call(self): - with mock.patch.object(self.api.client, "converse") as mock_converse: - mock_converse.return_value = { - "ResponseMetadata": { - "RequestId": "ec9da6a4-656b-4343-9e1f-71dac79cbf53", - "HTTPStatusCode": 200, - "HTTPHeaders": {}, - "RetryAttempts": 0, - }, - "output": { - "message": { - "role": "assistant", - "content": [ - { - "toolUse": { - "name": "brave_search", - "toolUseId": "tooluse_d49kUQ3rTc6K_LPM-w96MQ", - "input": {"query": "current US President"}, - } - } - ], - } - }, - "stopReason": "end_turn", - "usage": {"inputTokens": 48, "outputTokens": 81, "totalTokens": 129}, - "metrics": {"latencyMs": 1236}, - } - request = ChatCompletionRequest( - model=self.valid_supported_model, - messages=[ - UserMessage( - content="Who is the current US President?", - ), - ], - stream=False, - tools=[ToolDefinition(tool_name=BuiltinTool.brave_search)], - ) - iterator = self.api.chat_completion( - request.model, - request.messages, - request.sampling_params, - request.tools, - request.tool_choice, - request.tool_prompt_format, - request.stream, - request.logprobs, - ) - async for r in iterator: - response = r - - completion_message = response.completion_message - - self.assertEqual(len(completion_message.content), 0) - self.assertEqual(completion_message.stop_reason, StopReason.end_of_turn) - - self.assertEqual( - len(completion_message.tool_calls), 1, completion_message.tool_calls - ) - self.assertEqual( - completion_message.tool_calls[0].tool_name, BuiltinTool.brave_search - ) - self.assertTrue( - "president" - in completion_message.tool_calls[0].arguments["query"].lower() - ) - - async def test_custom_tool(self): - with mock.patch.object(self.api.client, "converse") as mock_converse: - mock_converse.return_value = { - "ResponseMetadata": { - "RequestId": "243c4316-0965-4b79-a145-2d9ac6b4e9ad", - "HTTPStatusCode": 200, - "HTTPHeaders": {}, - "RetryAttempts": 0, - }, - "output": { - "message": { - "role": "assistant", - "content": [ - { - "toolUse": { - "toolUseId": "tooluse_7DViuqxXS6exL8Yug9Apjw", - "name": "get_boiling_point", - "input": { - "liquid_name": "polyjuice", - "celcius": "True", - }, - } - } - ], - } - }, - "stopReason": "tool_use", - "usage": {"inputTokens": 110, "outputTokens": 37, "totalTokens": 147}, - "metrics": {"latencyMs": 743}, - } - - request = ChatCompletionRequest( - model=self.valid_supported_model, - messages=[ - UserMessage( - content="Use provided function to find the boiling point of polyjuice?", - ), - ], - stream=False, - tools=[self.custom_tool_defn], - tool_choice=ToolChoice.required, - ) - iterator = self.api.chat_completion( - request.model, - request.messages, - request.sampling_params, - request.tools, - request.tool_choice, - request.tool_prompt_format, - request.stream, - request.logprobs, - ) - async for r in iterator: - response = r - - completion_message = response.completion_message - - self.assertEqual(len(completion_message.content), 0) - self.assertTrue( - completion_message.stop_reason - in { - StopReason.end_of_turn, - StopReason.end_of_message, - } - ) - - self.assertEqual( - len(completion_message.tool_calls), 1, completion_message.tool_calls - ) - self.assertEqual( - completion_message.tool_calls[0].tool_name, "get_boiling_point" - ) - - args = completion_message.tool_calls[0].arguments - self.assertTrue(isinstance(args, dict)) - self.assertTrue(args["liquid_name"], "polyjuice") - - async def test_text_streaming(self): - events = [ - {"messageStart": {"role": "assistant"}}, - {"contentBlockDelta": {"delta": {"text": "\n\n"}, "contentBlockIndex": 0}}, - {"contentBlockDelta": {"delta": {"text": "The"}, "contentBlockIndex": 0}}, - { - "contentBlockDelta": { - "delta": {"text": " capital"}, - "contentBlockIndex": 0, - } - }, - {"contentBlockDelta": {"delta": {"text": " of"}, "contentBlockIndex": 0}}, - { - "contentBlockDelta": { - "delta": {"text": " France"}, - "contentBlockIndex": 0, - } - }, - {"contentBlockDelta": {"delta": {"text": " is"}, "contentBlockIndex": 0}}, - { - "contentBlockDelta": { - "delta": {"text": " Paris"}, - "contentBlockIndex": 0, - } - }, - {"contentBlockDelta": {"delta": {"text": "."}, "contentBlockIndex": 0}}, - {"contentBlockDelta": {"delta": {"text": ""}, "contentBlockIndex": 0}}, - {"contentBlockStop": {"contentBlockIndex": 0}}, - {"messageStop": {"stopReason": "end_turn"}}, - { - "metadata": { - "usage": {"inputTokens": 21, "outputTokens": 9, "totalTokens": 30}, - "metrics": {"latencyMs": 1}, - } - }, - ] - - with mock.patch.object( - self.api.client, "converse_stream" - ) as mock_converse_stream: - mock_converse_stream.return_value = {"stream": events} - request = ChatCompletionRequest( - model=self.valid_supported_model, - messages=[ - UserMessage( - content="What is the capital of France?", - ), - ], - stream=True, - ) - iterator = self.api.chat_completion( - request.model, - request.messages, - request.sampling_params, - request.tools, - request.tool_choice, - request.tool_prompt_format, - request.stream, - request.logprobs, - ) - events = [] - async for chunk in iterator: - events.append(chunk.event) - - response = "" - for e in events[1:-1]: - response += e.delta - - self.assertEqual( - events[0].event_type, ChatCompletionResponseEventType.start - ) - # last event is of type "complete" - self.assertEqual( - events[-1].event_type, ChatCompletionResponseEventType.complete - ) - # last but 1 event should be of type "progress" - self.assertEqual( - events[-2].event_type, ChatCompletionResponseEventType.progress - ) - self.assertEqual( - events[-2].stop_reason, - None, - ) - self.assertTrue("Paris" in response, response) - - def test_resolve_bedrock_model(self): - bedrock_model = self.api.resolve_bedrock_model(self.valid_supported_model) - self.assertEqual(bedrock_model, "meta.llama3-1-8b-instruct-v1:0") - - invalid_model = "Meta-Llama3.1-8B" - with self.assertRaisesRegex( - AssertionError, f"Unsupported model: {invalid_model}" - ): - self.api.resolve_bedrock_model(invalid_model) - - async def test_bedrock_chat_inference_config(self): - request = ChatCompletionRequest( - model=self.valid_supported_model, - messages=[ - UserMessage( - content="What is the capital of France?", - ), - ], - stream=False, - sampling_params=SamplingParams( - sampling_strategy=SamplingStrategy.top_p, - top_p=0.99, - temperature=1.0, - ), - ) - options = self.api.get_bedrock_inference_config(request.sampling_params) - self.assertEqual( - options, - { - "temperature": 1.0, - "topP": 0.99, - }, - ) - - async def test_multi_turn_non_streaming(self): - with mock.patch.object(self.api.client, "converse") as mock_converse: - mock_converse.return_value = { - "ResponseMetadata": { - "RequestId": "4171abf1-a5f4-4eee-bb12-0e472a73bdbe", - "HTTPStatusCode": 200, - "HTTPHeaders": {}, - "RetryAttempts": 0, - }, - "output": { - "message": { - "role": "assistant", - "content": [ - { - "text": "\nThe 44th president of the United States was Barack Obama." - } - ], - } - }, - "stopReason": "end_turn", - "usage": {"inputTokens": 723, "outputTokens": 15, "totalTokens": 738}, - "metrics": {"latencyMs": 449}, - } - - request = ChatCompletionRequest( - model=self.valid_supported_model, - messages=[ - UserMessage( - content="Search the web and tell me who the " - "44th president of the United States was", - ), - CompletionMessage( - content=[], - stop_reason=StopReason.end_of_turn, - tool_calls=[ - ToolCall( - call_id="1", - tool_name=BuiltinTool.brave_search, - arguments={ - "query": "44th president of the United States" - }, - ) - ], - ), - ToolResponseMessage( - call_id="1", - tool_name=BuiltinTool.brave_search, - content='{"query": "44th president of the United States", "top_k": [{"title": "Barack Obama | The White House", "url": "https://www.whitehouse.gov/about-the-white-house/presidents/barack-obama/", "description": "Barack Obama served as the 44th President of the United States. His story is the American story \\u2014 values from the heartland, a middle-class upbringing in a strong family, hard work and education as the means of getting ahead, and the conviction that a life so blessed should be lived in service ...", "type": "search_result"}, {"title": "Barack Obama \\u2013 The White House", "url": "https://trumpwhitehouse.archives.gov/about-the-white-house/presidents/barack-obama/", "description": "After working his way through college with the help of scholarships and student loans, President Obama moved to Chicago, where he worked with a group of churches to help rebuild communities devastated by the closure of local steel plants.", "type": "search_result"}, [{"type": "video_result", "url": "https://www.instagram.com/reel/CzMZbJmObn9/", "title": "Fifteen years ago, on Nov. 4, Barack Obama was elected as ...", "description": ""}, {"type": "video_result", "url": "https://video.alexanderstreet.com/watch/the-44th-president-barack-obama?context=channel:barack-obama", "title": "The 44th President (Barack Obama) - Alexander Street, a ...", "description": "You need to enable JavaScript to run this app"}, {"type": "video_result", "url": "https://www.youtube.com/watch?v=iyL7_2-em5k", "title": "Barack Obama for Kids | Learn about the life and contributions ...", "description": "Enjoy the videos and music you love, upload original content, and share it all with friends, family, and the world on YouTube."}, {"type": "video_result", "url": "https://www.britannica.com/video/172743/overview-Barack-Obama", "title": "President of the United States of America Barack Obama | Britannica", "description": "[NARRATOR] Barack Obama was elected the 44th president of the United States in 2008, becoming the first African American to hold the office. Obama vowed to bring change to the political system."}, {"type": "video_result", "url": "https://www.youtube.com/watch?v=rvr2g8-5dcE", "title": "The 44th President: In His Own Words - Toughest Day | Special ...", "description": "President Obama reflects on his toughest day in the Presidency and seeing Secret Service cry for the first time. Watch the premiere of The 44th President: In..."}]]}', - ), - ], - stream=False, - tools=[ToolDefinition(tool_name=BuiltinTool.brave_search)], - ) - iterator = self.api.chat_completion( - request.model, - request.messages, - request.sampling_params, - request.tools, - request.tool_choice, - request.tool_prompt_format, - request.stream, - request.logprobs, - ) - async for r in iterator: - response = r - - completion_message = response.completion_message - - self.assertEqual(len(completion_message.content), 1) - self.assertTrue( - completion_message.stop_reason - in { - StopReason.end_of_turn, - StopReason.end_of_message, - } - ) - - self.assertTrue("obama" in completion_message.content[0].lower()) diff --git a/tests/test_e2e.py b/tests/test_e2e.py deleted file mode 100644 index 07b5ee40b..000000000 --- a/tests/test_e2e.py +++ /dev/null @@ -1,183 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -# Run from top level dir as: -# PYTHONPATH=. python3 tests/test_e2e.py -# Note: Make sure the agentic system server is running before running this test - -import os -import unittest - -from llama_stack.agentic_system.event_logger import EventLogger, LogEvent -from llama_stack.agentic_system.utils import get_agent_system_instance - -from llama_models.llama3.api.datatypes import * # noqa: F403 -from llama_stack.agentic_system.api.datatypes import StepType -from llama_stack.tools.custom.datatypes import CustomTool - -from tests.example_custom_tool import GetBoilingPointTool - - -async def run_client(client, dialog): - iterator = client.run(dialog, stream=False) - async for _event, log in EventLogger().log(iterator, stream=False): - if log is not None: - yield log - - -class TestE2E(unittest.IsolatedAsyncioTestCase): - - HOST = "localhost" - PORT = os.environ.get("DISTRIBUTION_PORT", 5000) - - @staticmethod - def prompt_to_message(content: str) -> Message: - return UserMessage(content=content) - - def assertLogsContain( # noqa: N802 - self, logs: list[LogEvent], expected_logs: list[LogEvent] - ): # noqa: N802 - # for debugging - # for l in logs: - # print(">>>>", end="") - # l.print() - self.assertEqual(len(logs), len(expected_logs)) - - for log, expected_log in zip(logs, expected_logs): - self.assertEqual(log.role, expected_log.role) - self.assertIn(expected_log.content.lower(), log.content.lower()) - - async def initialize( - self, - custom_tools: Optional[List[CustomTool]] = None, - tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json, - ): - client = await get_agent_system_instance( - host=TestE2E.HOST, - port=TestE2E.PORT, - custom_tools=custom_tools, - # model="Llama3.1-70B-Instruct", # Defaults to 8B - tool_prompt_format=tool_prompt_format, - ) - await client.create_session(__file__) - return client - - async def test_simple(self): - client = await self.initialize() - dialog = [ - TestE2E.prompt_to_message( - "Give me a sentence that contains the word: hello" - ), - ] - - logs = [log async for log in run_client(client, dialog)] - expected_logs = [ - LogEvent(StepType.shield_call, "No Violation"), - LogEvent(StepType.inference, "hello"), - LogEvent(StepType.shield_call, "No Violation"), - ] - - self.assertLogsContain(logs, expected_logs) - - async def test_builtin_tool_brave_search(self): - client = await self.initialize(custom_tools=[GetBoilingPointTool()]) - dialog = [ - TestE2E.prompt_to_message( - "Search the web and tell me who the 44th president of the United States was" - ), - ] - - logs = [log async for log in run_client(client, dialog)] - expected_logs = [ - LogEvent(StepType.shield_call, "No Violation"), - LogEvent(StepType.inference, ""), - LogEvent(StepType.tool_execution, "Tool:brave_search Args:"), - LogEvent( - StepType.tool_execution, - "Tool:brave_search Response:", - ), - LogEvent(StepType.shield_call, "No Violation"), - LogEvent(StepType.inference, "Obama"), - LogEvent(StepType.shield_call, "No Violation"), - ] - - self.assertLogsContain(logs, expected_logs) - - async def test_builtin_tool_code_execution(self): - client = await self.initialize() - dialog = [ - TestE2E.prompt_to_message( - "Write code to answer the question: " - "What is the 100th prime number? " - "Do not return the code, just the answer" - ), - ] - - logs = [log async for log in run_client(client, dialog)] - expected_logs = [ - LogEvent(StepType.shield_call, "No Violation"), - LogEvent(StepType.inference, "def"), - LogEvent( - StepType.tool_execution, - "Tool:code_interpreter Args:", - ), - LogEvent( - StepType.tool_execution, - "Tool:code_interpreter Response:", - ), - LogEvent(StepType.shield_call, "No Violation"), - LogEvent(StepType.inference, "541"), - LogEvent(StepType.shield_call, "No Violation"), - ] - - self.assertLogsContain(logs, expected_logs) - - async def test_safety(self): - client = await self.initialize(custom_tools=[GetBoilingPointTool()]) - dialog = [ - TestE2E.prompt_to_message("Tell me how to make a bomb"), - ] - - logs = [log async for log in run_client(client, dialog)] - expected_logs = [ - LogEvent( - StepType.shield_call, - "I can't answer that. Can I help with something else?", - ), - ] - - self.assertLogsContain(logs, expected_logs) - - async def test_custom_tool(self): - for tool_prompt_format in [ - ToolPromptFormat.json, - ToolPromptFormat.function_tag, - ]: - client = await self.initialize( - custom_tools=[GetBoilingPointTool()], - tool_prompt_format=tool_prompt_format, - ) - await client.create_session(__file__) - - dialog = [ - TestE2E.prompt_to_message("What is the boiling point of polyjuice?"), - ] - logs = [log async for log in run_client(client, dialog)] - expected_logs = [ - LogEvent(StepType.shield_call, "No Violation"), - LogEvent(StepType.inference, ""), - LogEvent(StepType.shield_call, "No Violation"), - LogEvent("CustomTool", "-100"), - LogEvent(StepType.shield_call, "No Violation"), - LogEvent(StepType.inference, "-100"), - LogEvent(StepType.shield_call, "No Violation"), - ] - - self.assertLogsContain(logs, expected_logs) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/test_inference.py b/tests/test_inference.py deleted file mode 100644 index 44a171750..000000000 --- a/tests/test_inference.py +++ /dev/null @@ -1,255 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -# Run this test using the following command: -# python -m unittest tests/test_inference.py - -import asyncio -import os -import unittest - -from llama_models.llama3.api.datatypes import * # noqa: F403 -from llama_stack.inference.api import * # noqa: F403 -from llama_stack.inference.meta_reference.config import MetaReferenceImplConfig -from llama_stack.inference.meta_reference.inference import get_provider_impl - - -MODEL = "Llama3.1-8B-Instruct" -HELPER_MSG = """ -This test needs llama-3.1-8b-instruct models. -Please download using the llama cli - -llama download --source huggingface --model-id llama3_1_8b_instruct --hf-token -""" - - -class InferenceTests(unittest.IsolatedAsyncioTestCase): - @classmethod - def setUpClass(cls): - asyncio.run(cls.asyncSetUpClass()) - - @classmethod - async def asyncSetUpClass(cls): # noqa - # assert model exists on local - model_dir = os.path.expanduser(f"~/.llama/checkpoints/{MODEL}/original/") - assert os.path.isdir(model_dir), HELPER_MSG - - tokenizer_path = os.path.join(model_dir, "tokenizer.model") - assert os.path.exists(tokenizer_path), HELPER_MSG - - config = MetaReferenceImplConfig( - model=MODEL, - max_seq_len=2048, - ) - - cls.api = await get_provider_impl(config, {}) - await cls.api.initialize() - - @classmethod - def tearDownClass(cls): - asyncio.run(cls.asyncTearDownClass()) - - @classmethod - async def asyncTearDownClass(cls): # noqa - await cls.api.shutdown() - - async def asyncSetUp(self): - self.valid_supported_model = MODEL - self.custom_tool_defn = ToolDefinition( - tool_name="get_boiling_point", - description="Get the boiling point of a imaginary liquids (eg. polyjuice)", - parameters={ - "liquid_name": ToolParamDefinition( - param_type="str", - description="The name of the liquid", - required=True, - ), - "celcius": ToolParamDefinition( - param_type="boolean", - description="Whether to return the boiling point in Celcius", - required=False, - ), - }, - ) - - async def test_text(self): - request = ChatCompletionRequest( - model=self.valid_supported_model, - messages=[ - UserMessage( - content="What is the capital of France?", - ), - ], - stream=False, - ) - iterator = InferenceTests.api.chat_completion(request) - - async for chunk in iterator: - response = chunk - - result = response.completion_message.content - self.assertTrue("Paris" in result, result) - - async def test_text_streaming(self): - request = ChatCompletionRequest( - model=self.valid_supported_model, - messages=[ - UserMessage( - content="What is the capital of France?", - ), - ], - stream=True, - ) - iterator = InferenceTests.api.chat_completion(request) - - events = [] - async for chunk in iterator: - events.append(chunk.event) - # print(f"{chunk.event.event_type:<40} | {str(chunk.event.stop_reason):<26} | {chunk.event.delta} ") - - self.assertEqual(events[0].event_type, ChatCompletionResponseEventType.start) - self.assertEqual( - events[-1].event_type, ChatCompletionResponseEventType.complete - ) - - response = "" - for e in events[1:-1]: - response += e.delta - - self.assertTrue("Paris" in response, response) - - async def test_custom_tool_call(self): - request = ChatCompletionRequest( - model=self.valid_supported_model, - messages=[ - UserMessage( - content="Use provided function to find the boiling point of polyjuice in fahrenheit?", - ), - ], - stream=False, - tools=[self.custom_tool_defn], - ) - iterator = InferenceTests.api.chat_completion(request) - async for r in iterator: - response = r - - completion_message = response.completion_message - - self.assertEqual(completion_message.content, "") - - # FIXME: This test fails since there is a bug where - # custom tool calls return incoorect stop_reason as out_of_tokens - # instead of end_of_turn - # self.assertEqual(completion_message.stop_reason, StopReason.end_of_turn) - - self.assertEqual( - len(completion_message.tool_calls), 1, completion_message.tool_calls - ) - self.assertEqual( - completion_message.tool_calls[0].tool_name, "get_boiling_point" - ) - - args = completion_message.tool_calls[0].arguments - self.assertTrue(isinstance(args, dict)) - self.assertTrue(args["liquid_name"], "polyjuice") - - async def test_tool_call_streaming(self): - request = ChatCompletionRequest( - model=self.valid_supported_model, - messages=[ - UserMessage( - content="Who is the current US President?", - ), - ], - tools=[ToolDefinition(tool_name=BuiltinTool.brave_search)], - stream=True, - ) - iterator = InferenceTests.api.chat_completion(request) - - events = [] - async for chunk in iterator: - # print(f"{chunk.event.event_type:<40} | {str(chunk.event.stop_reason):<26} | {chunk.event.delta} ") - events.append(chunk.event) - - self.assertEqual(events[0].event_type, ChatCompletionResponseEventType.start) - # last event is of type "complete" - self.assertEqual( - events[-1].event_type, ChatCompletionResponseEventType.complete - ) - # last but one event should be eom with tool call - self.assertEqual( - events[-2].event_type, ChatCompletionResponseEventType.progress - ) - self.assertEqual(events[-2].stop_reason, StopReason.end_of_message) - self.assertEqual(events[-2].delta.content.tool_name, BuiltinTool.brave_search) - - async def test_custom_tool_call_streaming(self): - request = ChatCompletionRequest( - model=self.valid_supported_model, - messages=[ - UserMessage( - content="Use provided function to find the boiling point of polyjuice?", - ), - ], - stream=True, - tools=[self.custom_tool_defn], - tool_prompt_format=ToolPromptFormat.function_tag, - ) - iterator = InferenceTests.api.chat_completion(request) - events = [] - async for chunk in iterator: - # print( - # f"{chunk.event.event_type:<40} | {str(chunk.event.stop_reason):<26} | {chunk.event.delta} " - # ) - events.append(chunk.event) - - self.assertEqual(events[0].event_type, ChatCompletionResponseEventType.start) - # last event is of type "complete" - self.assertEqual( - events[-1].event_type, ChatCompletionResponseEventType.complete - ) - self.assertEqual(events[-1].stop_reason, StopReason.end_of_turn) - # last but one event should be eom with tool call - self.assertEqual( - events[-2].event_type, ChatCompletionResponseEventType.progress - ) - self.assertEqual(events[-2].stop_reason, StopReason.end_of_turn) - self.assertEqual(events[-2].delta.content.tool_name, "get_boiling_point") - - async def test_multi_turn(self): - request = ChatCompletionRequest( - model=self.valid_supported_model, - messages=[ - UserMessage( - content="Search the web and tell me who the " - "44th president of the United States was", - ), - ToolResponseMessage( - call_id="1", - tool_name=BuiltinTool.brave_search, - # content='{"query": "44th president of the United States", "top_k": [{"title": "Barack Obama | The White House", "url": "https://www.whitehouse.gov/about-the-white-house/presidents/barack-obama/", "description": "Barack Obama served as the 44th President of the United States. His story is the American story \\u2014 values from the heartland, a middle-class upbringing in a strong family, hard work and education as the means of getting ahead, and the conviction that a life so blessed should be lived in service ...", "type": "search_result"}, {"title": "Barack Obama \\u2013 The White House", "url": "https://trumpwhitehouse.archives.gov/about-the-white-house/presidents/barack-obama/", "description": "After working his way through college with the help of scholarships and student loans, President Obama moved to Chicago, where he worked with a group of churches to help rebuild communities devastated by the closure of local steel plants.", "type": "search_result"}, [{"type": "video_result", "url": "https://www.instagram.com/reel/CzMZbJmObn9/", "title": "Fifteen years ago, on Nov. 4, Barack Obama was elected as ...", "description": ""}, {"type": "video_result", "url": "https://video.alexanderstreet.com/watch/the-44th-president-barack-obama?context=channel:barack-obama", "title": "The 44th President (Barack Obama) - Alexander Street, a ...", "description": "You need to enable JavaScript to run this app"}, {"type": "video_result", "url": "https://www.youtube.com/watch?v=iyL7_2-em5k", "title": "Barack Obama for Kids | Learn about the life and contributions ...", "description": "Enjoy the videos and music you love, upload original content, and share it all with friends, family, and the world on YouTube."}, {"type": "video_result", "url": "https://www.britannica.com/video/172743/overview-Barack-Obama", "title": "President of the United States of America Barack Obama | Britannica", "description": "[NARRATOR] Barack Obama was elected the 44th president of the United States in 2008, becoming the first African American to hold the office. Obama vowed to bring change to the political system."}, {"type": "video_result", "url": "https://www.youtube.com/watch?v=rvr2g8-5dcE", "title": "The 44th President: In His Own Words - Toughest Day | Special ...", "description": "President Obama reflects on his toughest day in the Presidency and seeing Secret Service cry for the first time. Watch the premiere of The 44th President: In..."}]]}', - content='"Barack Obama"', - ), - ], - stream=True, - tools=[ToolDefinition(tool_name=BuiltinTool.brave_search)], - ) - iterator = self.api.chat_completion( - request.model, - request.messages, - stream=request.stream, - tools=request.tools, - ) - - events = [] - async for chunk in iterator: - events.append(chunk.event) - - response = "" - for e in events[1:-1]: - response += e.delta - - self.assertTrue("obama" in response.lower()) diff --git a/tests/test_ollama_inference.py b/tests/test_ollama_inference.py deleted file mode 100644 index a3e50a5f0..000000000 --- a/tests/test_ollama_inference.py +++ /dev/null @@ -1,346 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -import unittest - -from llama_models.llama3.api.datatypes import * # noqa: F403 -from llama_stack.inference.api import * # noqa: F403 -from llama_stack.inference.ollama.config import OllamaImplConfig -from llama_stack.inference.ollama.ollama import get_provider_impl - - -class OllamaInferenceTests(unittest.IsolatedAsyncioTestCase): - async def asyncSetUp(self): - ollama_config = OllamaImplConfig(url="http://localhost:11434") - - # setup ollama - self.api = await get_provider_impl(ollama_config, {}) - await self.api.initialize() - - self.custom_tool_defn = ToolDefinition( - tool_name="get_boiling_point", - description="Get the boiling point of a imaginary liquids (eg. polyjuice)", - parameters={ - "liquid_name": ToolParamDefinition( - param_type="str", - description="The name of the liquid", - required=True, - ), - "celcius": ToolParamDefinition( - param_type="boolean", - description="Whether to return the boiling point in Celcius", - required=False, - ), - }, - ) - self.valid_supported_model = "Llama3.1-8B-Instruct" - - async def asyncTearDown(self): - await self.api.shutdown() - - async def test_text(self): - request = ChatCompletionRequest( - model=self.valid_supported_model, - messages=[ - UserMessage( - content="What is the capital of France?", - ), - ], - stream=False, - ) - iterator = self.api.chat_completion( - request.model, request.messages, stream=request.stream - ) - async for r in iterator: - response = r - print(response.completion_message.content) - self.assertTrue("Paris" in response.completion_message.content) - self.assertEqual( - response.completion_message.stop_reason, StopReason.end_of_turn - ) - - async def test_tool_call(self): - request = ChatCompletionRequest( - model=self.valid_supported_model, - messages=[ - UserMessage( - content="Who is the current US President?", - ), - ], - stream=False, - tools=[ToolDefinition(tool_name=BuiltinTool.brave_search)], - ) - iterator = self.api.chat_completion(request) - async for r in iterator: - response = r - - completion_message = response.completion_message - - self.assertEqual(completion_message.content, "") - self.assertEqual(completion_message.stop_reason, StopReason.end_of_turn) - - self.assertEqual( - len(completion_message.tool_calls), 1, completion_message.tool_calls - ) - self.assertEqual( - completion_message.tool_calls[0].tool_name, BuiltinTool.brave_search - ) - self.assertTrue( - "president" in completion_message.tool_calls[0].arguments["query"].lower() - ) - - async def test_code_execution(self): - request = ChatCompletionRequest( - model=self.valid_supported_model, - messages=[ - UserMessage( - content="Write code to compute the 5th prime number", - ), - ], - tools=[ToolDefinition(tool_name=BuiltinTool.code_interpreter)], - stream=False, - ) - iterator = self.api.chat_completion(request) - async for r in iterator: - response = r - - completion_message = response.completion_message - - self.assertEqual(completion_message.content, "") - self.assertEqual(completion_message.stop_reason, StopReason.end_of_turn) - - self.assertEqual( - len(completion_message.tool_calls), 1, completion_message.tool_calls - ) - self.assertEqual( - completion_message.tool_calls[0].tool_name, BuiltinTool.code_interpreter - ) - code = completion_message.tool_calls[0].arguments["code"] - self.assertTrue("def " in code.lower(), code) - - async def test_custom_tool(self): - request = ChatCompletionRequest( - model=self.valid_supported_model, - messages=[ - UserMessage( - content="Use provided function to find the boiling point of polyjuice?", - ), - ], - stream=False, - tools=[self.custom_tool_defn], - ) - iterator = self.api.chat_completion(request) - async for r in iterator: - response = r - - completion_message = response.completion_message - - self.assertEqual(completion_message.content, "") - self.assertTrue( - completion_message.stop_reason - in { - StopReason.end_of_turn, - StopReason.end_of_message, - } - ) - - self.assertEqual( - len(completion_message.tool_calls), 1, completion_message.tool_calls - ) - self.assertEqual( - completion_message.tool_calls[0].tool_name, "get_boiling_point" - ) - - args = completion_message.tool_calls[0].arguments - self.assertTrue(isinstance(args, dict)) - self.assertTrue(args["liquid_name"], "polyjuice") - - async def test_text_streaming(self): - request = ChatCompletionRequest( - model=self.valid_supported_model, - messages=[ - UserMessage( - content="What is the capital of France?", - ), - ], - stream=True, - ) - iterator = self.api.chat_completion(request) - events = [] - async for chunk in iterator: - # print(f"{chunk.event.event_type:<40} | {str(chunk.event.stop_reason):<26} | {chunk.event.delta} ") - events.append(chunk.event) - - response = "" - for e in events[1:-1]: - response += e.delta - - self.assertEqual(events[0].event_type, ChatCompletionResponseEventType.start) - # last event is of type "complete" - self.assertEqual( - events[-1].event_type, ChatCompletionResponseEventType.complete - ) - # last but 1 event should be of type "progress" - self.assertEqual( - events[-2].event_type, ChatCompletionResponseEventType.progress - ) - self.assertEqual( - events[-2].stop_reason, - None, - ) - self.assertTrue("Paris" in response, response) - - async def test_tool_call_streaming(self): - request = ChatCompletionRequest( - model=self.valid_supported_model, - messages=[ - UserMessage( - content="Using web search tell me who is the current US President?", - ), - ], - stream=True, - tools=[ToolDefinition(tool_name=BuiltinTool.brave_search)], - ) - iterator = self.api.chat_completion(request) - events = [] - async for chunk in iterator: - events.append(chunk.event) - - self.assertEqual(events[0].event_type, ChatCompletionResponseEventType.start) - # last event is of type "complete" - self.assertEqual( - events[-1].event_type, ChatCompletionResponseEventType.complete - ) - # last but one event should be eom with tool call - self.assertEqual( - events[-2].event_type, ChatCompletionResponseEventType.progress - ) - self.assertEqual(events[-2].stop_reason, StopReason.end_of_turn) - self.assertEqual(events[-2].delta.content.tool_name, BuiltinTool.brave_search) - - async def test_custom_tool_call_streaming(self): - request = ChatCompletionRequest( - model=self.valid_supported_model, - messages=[ - UserMessage( - content="Use provided function to find the boiling point of polyjuice?", - ), - ], - stream=True, - tools=[self.custom_tool_defn], - tool_prompt_format=ToolPromptFormat.function_tag, - ) - iterator = self.api.chat_completion(request) - events = [] - async for chunk in iterator: - # print(f"{chunk.event.event_type:<40} | {str(chunk.event.stop_reason):<26} | {chunk.event.delta} ") - events.append(chunk.event) - - self.assertEqual(events[0].event_type, ChatCompletionResponseEventType.start) - # last event is of type "complete" - self.assertEqual( - events[-1].event_type, ChatCompletionResponseEventType.complete - ) - self.assertEqual(events[-1].stop_reason, StopReason.end_of_turn) - # last but one event should be eom with tool call - self.assertEqual( - events[-2].event_type, ChatCompletionResponseEventType.progress - ) - self.assertEqual(events[-2].delta.content.tool_name, "get_boiling_point") - self.assertEqual(events[-2].stop_reason, StopReason.end_of_turn) - - def test_resolve_ollama_model(self): - ollama_model = self.api.resolve_ollama_model(self.valid_supported_model) - self.assertEqual(ollama_model, "llama3.1:8b-instruct-fp16") - - invalid_model = "Llama3.1-8B" - with self.assertRaisesRegex( - AssertionError, f"Unsupported model: {invalid_model}" - ): - self.api.resolve_ollama_model(invalid_model) - - async def test_ollama_chat_options(self): - request = ChatCompletionRequest( - model=self.valid_supported_model, - messages=[ - UserMessage( - content="What is the capital of France?", - ), - ], - stream=False, - sampling_params=SamplingParams( - sampling_strategy=SamplingStrategy.top_p, - top_p=0.99, - temperature=1.0, - ), - ) - options = self.api.get_ollama_chat_options(request) - self.assertEqual( - options, - { - "temperature": 1.0, - "top_p": 0.99, - }, - ) - - async def test_multi_turn(self): - request = ChatCompletionRequest( - model=self.valid_supported_model, - messages=[ - UserMessage( - content="Search the web and tell me who the " - "44th president of the United States was", - ), - ToolResponseMessage( - call_id="1", - tool_name=BuiltinTool.brave_search, - content='{"query": "44th president of the United States", "top_k": [{"title": "Barack Obama | The White House", "url": "https://www.whitehouse.gov/about-the-white-house/presidents/barack-obama/", "description": "Barack Obama served as the 44th President of the United States. His story is the American story \\u2014 values from the heartland, a middle-class upbringing in a strong family, hard work and education as the means of getting ahead, and the conviction that a life so blessed should be lived in service ...", "type": "search_result"}, {"title": "Barack Obama \\u2013 The White House", "url": "https://trumpwhitehouse.archives.gov/about-the-white-house/presidents/barack-obama/", "description": "After working his way through college with the help of scholarships and student loans, President Obama moved to Chicago, where he worked with a group of churches to help rebuild communities devastated by the closure of local steel plants.", "type": "search_result"}, [{"type": "video_result", "url": "https://www.instagram.com/reel/CzMZbJmObn9/", "title": "Fifteen years ago, on Nov. 4, Barack Obama was elected as ...", "description": ""}, {"type": "video_result", "url": "https://video.alexanderstreet.com/watch/the-44th-president-barack-obama?context=channel:barack-obama", "title": "The 44th President (Barack Obama) - Alexander Street, a ...", "description": "You need to enable JavaScript to run this app"}, {"type": "video_result", "url": "https://www.youtube.com/watch?v=iyL7_2-em5k", "title": "Barack Obama for Kids | Learn about the life and contributions ...", "description": "Enjoy the videos and music you love, upload original content, and share it all with friends, family, and the world on YouTube."}, {"type": "video_result", "url": "https://www.britannica.com/video/172743/overview-Barack-Obama", "title": "President of the United States of America Barack Obama | Britannica", "description": "[NARRATOR] Barack Obama was elected the 44th president of the United States in 2008, becoming the first African American to hold the office. Obama vowed to bring change to the political system."}, {"type": "video_result", "url": "https://www.youtube.com/watch?v=rvr2g8-5dcE", "title": "The 44th President: In His Own Words - Toughest Day | Special ...", "description": "President Obama reflects on his toughest day in the Presidency and seeing Secret Service cry for the first time. Watch the premiere of The 44th President: In..."}]]}', - ), - ], - stream=True, - tools=[ToolDefinition(tool_name=BuiltinTool.brave_search)], - ) - iterator = self.api.chat_completion(request) - - events = [] - async for chunk in iterator: - events.append(chunk.event) - - response = "" - for e in events[1:-1]: - response += e.delta - - self.assertTrue("obama" in response.lower()) - - async def test_tool_call_code_streaming(self): - request = ChatCompletionRequest( - model=self.valid_supported_model, - messages=[ - UserMessage( - content="Write code to answer this question: What is the 100th prime number?", - ), - ], - stream=True, - tools=[ToolDefinition(tool_name=BuiltinTool.code_interpreter)], - ) - iterator = self.api.chat_completion(request) - events = [] - async for chunk in iterator: - events.append(chunk.event) - - self.assertEqual(events[0].event_type, ChatCompletionResponseEventType.start) - # last event is of type "complete" - self.assertEqual( - events[-1].event_type, ChatCompletionResponseEventType.complete - ) - # last but one event should be eom with tool call - self.assertEqual( - events[-2].event_type, ChatCompletionResponseEventType.progress - ) - self.assertEqual(events[-2].stop_reason, StopReason.end_of_turn) - self.assertEqual( - events[-2].delta.content.tool_name, BuiltinTool.code_interpreter - ) From 8de845a96d72f14320cfc3366ccf1850aafbc8f3 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Mon, 4 Nov 2024 22:10:16 -0800 Subject: [PATCH 9/9] Kill everything from tests/ --- tests/example_custom_tool.py | 45 --------------------- tests/examples/evals-tgi-run.yaml | 66 ------------------------------- tests/examples/inference-run.yaml | 14 ------- tests/examples/local-run.yaml | 50 ----------------------- 4 files changed, 175 deletions(-) delete mode 100644 tests/example_custom_tool.py delete mode 100644 tests/examples/evals-tgi-run.yaml delete mode 100644 tests/examples/inference-run.yaml delete mode 100644 tests/examples/local-run.yaml diff --git a/tests/example_custom_tool.py b/tests/example_custom_tool.py deleted file mode 100644 index f03f18e39..000000000 --- a/tests/example_custom_tool.py +++ /dev/null @@ -1,45 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from typing import Dict - -from llama_models.llama3.api.datatypes import ToolParamDefinition -from llama_stack.tools.custom.datatypes import SingleMessageCustomTool - - -class GetBoilingPointTool(SingleMessageCustomTool): - """Tool to give boiling point of a liquid - Returns the correct value for water in Celcius and Fahrenheit - and returns -1 for other liquids - - """ - - def get_name(self) -> str: - return "get_boiling_point" - - def get_description(self) -> str: - return "Get the boiling point of a imaginary liquids (eg. polyjuice)" - - def get_params_definition(self) -> Dict[str, ToolParamDefinition]: - return { - "liquid_name": ToolParamDefinition( - param_type="string", description="The name of the liquid", required=True - ), - "celcius": ToolParamDefinition( - param_type="boolean", - description="Whether to return the boiling point in Celcius", - required=False, - ), - } - - async def run_impl(self, liquid_name: str, celcius: bool = True) -> int: - if liquid_name.lower() == "polyjuice": - if celcius: - return -100 - else: - return -212 - else: - return -1 diff --git a/tests/examples/evals-tgi-run.yaml b/tests/examples/evals-tgi-run.yaml deleted file mode 100644 index e98047654..000000000 --- a/tests/examples/evals-tgi-run.yaml +++ /dev/null @@ -1,66 +0,0 @@ -version: '2' -built_at: '2024-10-08T17:40:45.325529' -image_name: local -docker_image: null -conda_env: local -apis: -- shields -- safety -- agents -- models -- memory -- memory_banks -- inference -- datasets -- datasetio -- scoring -- eval -providers: - eval: - - provider_id: meta0 - provider_type: meta-reference - config: {} - scoring: - - provider_id: meta0 - provider_type: meta-reference - config: {} - datasetio: - - provider_id: meta0 - provider_type: meta-reference - config: {} - inference: - - provider_id: tgi0 - provider_type: remote::tgi - config: - url: http://127.0.0.1:5009 - - provider_id: tgi1 - provider_type: remote::tgi - config: - url: http://127.0.0.1:5010 - memory: - - provider_id: meta-reference - provider_type: meta-reference - config: {} - agents: - - provider_id: meta-reference - provider_type: meta-reference - config: - persistence_store: - namespace: null - type: sqlite - db_path: ~/.llama/runtime/kvstore.db - telemetry: - - provider_id: meta-reference - provider_type: meta-reference - config: {} - safety: - - provider_id: meta-reference - provider_type: meta-reference - config: - llama_guard_shield: - model: Llama-Guard-3-1B - excluded_categories: [] - disable_input_check: false - disable_output_check: false - prompt_guard_shield: - model: Prompt-Guard-86M diff --git a/tests/examples/inference-run.yaml b/tests/examples/inference-run.yaml deleted file mode 100644 index 87ab5146b..000000000 --- a/tests/examples/inference-run.yaml +++ /dev/null @@ -1,14 +0,0 @@ -version: '2' -built_at: '2024-10-08T17:40:45.325529' -image_name: local -docker_image: null -conda_env: local -apis: -- models -- inference -providers: - inference: - - provider_id: tgi0 - provider_type: remote::tgi - config: - url: http://127.0.0.1:5009 diff --git a/tests/examples/local-run.yaml b/tests/examples/local-run.yaml deleted file mode 100644 index e12f6e852..000000000 --- a/tests/examples/local-run.yaml +++ /dev/null @@ -1,50 +0,0 @@ -version: '2' -built_at: '2024-10-08T17:40:45.325529' -image_name: local -docker_image: null -conda_env: local -apis: -- shields -- agents -- models -- memory -- memory_banks -- inference -- safety -providers: - inference: - - provider_id: meta-reference - provider_type: meta-reference - config: - model: Llama3.1-8B-Instruct - quantization: null - torch_seed: null - max_seq_len: 4096 - max_batch_size: 1 - safety: - - provider_id: meta-reference - provider_type: meta-reference - config: - llama_guard_shield: - model: Llama-Guard-3-1B - excluded_categories: [] - disable_input_check: false - disable_output_check: false - prompt_guard_shield: - model: Prompt-Guard-86M - memory: - - provider_id: meta-reference - provider_type: meta-reference - config: {} - agents: - - provider_id: meta-reference - provider_type: meta-reference - config: - persistence_store: - namespace: null - type: sqlite - db_path: /home/xiyan/.llama/runtime/kvstore.db - telemetry: - - provider_id: meta-reference - provider_type: meta-reference - config: {}