mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-15 22:47:59 +00:00
Merge branch 'main' into kill_configure
This commit is contained in:
commit
47d91b10fb
29 changed files with 119 additions and 1463 deletions
|
@ -23,7 +23,7 @@ class ShieldDef(BaseModel):
|
|||
identifier: str = Field(
|
||||
description="A unique identifier for the shield type",
|
||||
)
|
||||
type: str = Field(
|
||||
shield_type: str = Field(
|
||||
description="The type of shield this is; the value is one of the ShieldType enum"
|
||||
)
|
||||
params: Dict[str, Any] = Field(
|
||||
|
|
|
@ -25,6 +25,7 @@ from llama_stack.distribution.utils.config_dirs import BUILDS_BASE_DIR
|
|||
# These are the dependencies needed by the distribution server.
|
||||
# `llama-stack` is automatically installed by the installation script.
|
||||
SERVER_DEPENDENCIES = [
|
||||
"aiosqlite",
|
||||
"fastapi",
|
||||
"fire",
|
||||
"httpx",
|
||||
|
|
|
@ -83,6 +83,7 @@ def create_api_client_class(protocol, additional_protocol) -> Type:
|
|||
j = response.json()
|
||||
if j is None:
|
||||
return None
|
||||
# print(f"({protocol.__name__}) Returning {j}, type {return_type}")
|
||||
return parse_obj_as(return_type, j)
|
||||
|
||||
async def _call_streaming(self, method_name: str, *args, **kwargs) -> Any:
|
||||
|
@ -102,14 +103,15 @@ def create_api_client_class(protocol, additional_protocol) -> Type:
|
|||
if line.startswith("data:"):
|
||||
data = line[len("data: ") :]
|
||||
try:
|
||||
data = json.loads(data)
|
||||
if "error" in data:
|
||||
cprint(data, "red")
|
||||
continue
|
||||
|
||||
yield parse_obj_as(return_type, json.loads(data))
|
||||
yield parse_obj_as(return_type, data)
|
||||
except Exception as e:
|
||||
print(data)
|
||||
print(f"Error with parsing or validation: {e}")
|
||||
print(data)
|
||||
|
||||
def httpx_request_params(self, method_name: str, *args, **kwargs) -> dict:
|
||||
webmethod, sig = self.routes[method_name]
|
||||
|
|
|
@ -178,16 +178,17 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
await register_object_with_provider(obj, p)
|
||||
await self.dist_registry.register(obj)
|
||||
|
||||
async def get_all_with_type(self, type: str) -> List[RoutableObjectWithProvider]:
|
||||
objs = await self.dist_registry.get_all()
|
||||
return [obj for obj in objs if obj.type == type]
|
||||
|
||||
|
||||
class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
||||
async def list_models(self) -> List[ModelDefWithProvider]:
|
||||
objects = []
|
||||
for objs in self.registry.values():
|
||||
objects.extend(objs)
|
||||
return objects
|
||||
return await self.get_all_with_type("model")
|
||||
|
||||
async def get_model(self, identifier: str) -> Optional[ModelDefWithProvider]:
|
||||
return self.get_object_by_identifier(identifier)
|
||||
return await self.get_object_by_identifier(identifier)
|
||||
|
||||
async def register_model(self, model: ModelDefWithProvider) -> None:
|
||||
await self.register_object(model)
|
||||
|
@ -195,13 +196,10 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
|||
|
||||
class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
|
||||
async def list_shields(self) -> List[ShieldDef]:
|
||||
objects = []
|
||||
for objs in self.registry.values():
|
||||
objects.extend(objs)
|
||||
return objects
|
||||
return await self.get_all_with_type("shield")
|
||||
|
||||
async def get_shield(self, shield_type: str) -> Optional[ShieldDefWithProvider]:
|
||||
return self.get_object_by_identifier(shield_type)
|
||||
return await self.get_object_by_identifier(shield_type)
|
||||
|
||||
async def register_shield(self, shield: ShieldDefWithProvider) -> None:
|
||||
await self.register_object(shield)
|
||||
|
@ -209,15 +207,12 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):
|
|||
|
||||
class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
|
||||
async def list_memory_banks(self) -> List[MemoryBankDefWithProvider]:
|
||||
objects = []
|
||||
for objs in self.registry.values():
|
||||
objects.extend(objs)
|
||||
return objects
|
||||
return await self.get_all_with_type("memory_bank")
|
||||
|
||||
async def get_memory_bank(
|
||||
self, identifier: str
|
||||
) -> Optional[MemoryBankDefWithProvider]:
|
||||
return self.get_object_by_identifier(identifier)
|
||||
return await self.get_object_by_identifier(identifier)
|
||||
|
||||
async def register_memory_bank(
|
||||
self, memory_bank: MemoryBankDefWithProvider
|
||||
|
@ -227,15 +222,12 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
|
|||
|
||||
class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
|
||||
async def list_datasets(self) -> List[DatasetDefWithProvider]:
|
||||
objects = []
|
||||
for objs in self.registry.values():
|
||||
objects.extend(objs)
|
||||
return objects
|
||||
return await self.get_all_with_type("dataset")
|
||||
|
||||
async def get_dataset(
|
||||
self, dataset_identifier: str
|
||||
) -> Optional[DatasetDefWithProvider]:
|
||||
return self.get_object_by_identifier(dataset_identifier)
|
||||
return await self.get_object_by_identifier(dataset_identifier)
|
||||
|
||||
async def register_dataset(self, dataset_def: DatasetDefWithProvider) -> None:
|
||||
await self.register_object(dataset_def)
|
||||
|
@ -243,15 +235,12 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
|
|||
|
||||
class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, Scoring):
|
||||
async def list_scoring_functions(self) -> List[ScoringFnDefWithProvider]:
|
||||
objects = []
|
||||
for objs in self.registry.values():
|
||||
objects.extend(objs)
|
||||
return objects
|
||||
return await self.get_all_with_type("scoring_function")
|
||||
|
||||
async def get_scoring_function(
|
||||
self, name: str
|
||||
) -> Optional[ScoringFnDefWithProvider]:
|
||||
return self.get_object_by_identifier(name)
|
||||
return await self.get_object_by_identifier(name)
|
||||
|
||||
async def register_scoring_function(
|
||||
self, function_def: ScoringFnDefWithProvider
|
||||
|
|
|
@ -134,7 +134,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
|
||||
stream = _to_async_generator()
|
||||
async for chunk in process_chat_completion_stream_response(
|
||||
request, stream, self.formatter
|
||||
stream, self.formatter
|
||||
):
|
||||
yield chunk
|
||||
|
||||
|
|
|
@ -37,7 +37,7 @@ class TogetherSafetyImpl(Safety, NeedsRequestProviderData, ShieldsProtocolPrivat
|
|||
return [
|
||||
ShieldDef(
|
||||
identifier=ShieldType.llama_guard.value,
|
||||
type=ShieldType.llama_guard.value,
|
||||
shield_type=ShieldType.llama_guard.value,
|
||||
params={},
|
||||
)
|
||||
]
|
||||
|
|
|
@ -25,8 +25,8 @@ class MetaReferenceCodeScannerSafetyImpl(Safety):
|
|||
pass
|
||||
|
||||
async def register_shield(self, shield: ShieldDef) -> None:
|
||||
if shield.type != ShieldType.code_scanner.value:
|
||||
raise ValueError(f"Unsupported safety shield type: {shield.type}")
|
||||
if shield.shield_type != ShieldType.code_scanner.value:
|
||||
raise ValueError(f"Unsupported safety shield type: {shield.shield_type}")
|
||||
|
||||
async def run_shield(
|
||||
self,
|
||||
|
|
|
@ -49,7 +49,7 @@ class MetaReferenceSafetyImpl(Safety, ShieldsProtocolPrivate):
|
|||
return [
|
||||
ShieldDef(
|
||||
identifier=shield_type,
|
||||
type=shield_type,
|
||||
shield_type=shield_type,
|
||||
params={},
|
||||
)
|
||||
for shield_type in self.available_shields
|
||||
|
@ -92,14 +92,14 @@ class MetaReferenceSafetyImpl(Safety, ShieldsProtocolPrivate):
|
|||
return RunShieldResponse(violation=violation)
|
||||
|
||||
def get_shield_impl(self, shield: ShieldDef) -> ShieldBase:
|
||||
if shield.type == ShieldType.llama_guard.value:
|
||||
if shield.shield_type == ShieldType.llama_guard.value:
|
||||
cfg = self.config.llama_guard_shield
|
||||
return LlamaGuardShield(
|
||||
model=cfg.model,
|
||||
inference_api=self.inference_api,
|
||||
excluded_categories=cfg.excluded_categories,
|
||||
)
|
||||
elif shield.type == ShieldType.prompt_guard.value:
|
||||
elif shield.shield_type == ShieldType.prompt_guard.value:
|
||||
model_dir = model_local_dir(PROMPT_GUARD_MODEL)
|
||||
subtype = shield.params.get("prompt_guard_type", "injection")
|
||||
if subtype == "injection":
|
||||
|
@ -109,4 +109,4 @@ class MetaReferenceSafetyImpl(Safety, ShieldsProtocolPrivate):
|
|||
else:
|
||||
raise ValueError(f"Unknown prompt guard type: {subtype}")
|
||||
else:
|
||||
raise ValueError(f"Unknown shield type: {shield.type}")
|
||||
raise ValueError(f"Unknown shield type: {shield.shield_type}")
|
||||
|
|
|
@ -46,11 +46,21 @@ DEFAULT_PROVIDER_COMBINATIONS = [
|
|||
id="together",
|
||||
marks=pytest.mark.together,
|
||||
),
|
||||
pytest.param(
|
||||
{
|
||||
"inference": "remote",
|
||||
"safety": "remote",
|
||||
"memory": "remote",
|
||||
"agents": "remote",
|
||||
},
|
||||
id="remote",
|
||||
marks=pytest.mark.remote,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def pytest_configure(config):
|
||||
for mark in ["meta_reference", "ollama", "together"]:
|
||||
for mark in ["meta_reference", "ollama", "together", "remote"]:
|
||||
config.addinivalue_line(
|
||||
"markers",
|
||||
f"{mark}: marks tests as {mark} specific",
|
||||
|
|
|
@ -18,7 +18,12 @@ from llama_stack.providers.impls.meta_reference.agents import (
|
|||
from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
|
||||
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
||||
|
||||
from ..conftest import ProviderFixture
|
||||
from ..conftest import ProviderFixture, remote_stack_fixture
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def agents_remote() -> ProviderFixture:
|
||||
return remote_stack_fixture()
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
|
@ -40,7 +45,7 @@ def agents_meta_reference() -> ProviderFixture:
|
|||
)
|
||||
|
||||
|
||||
AGENTS_FIXTURES = ["meta_reference"]
|
||||
AGENTS_FIXTURES = ["meta_reference", "remote"]
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="session")
|
||||
|
|
|
@ -109,7 +109,6 @@ class TestAgents:
|
|||
turn_response = [
|
||||
chunk async for chunk in await agents_impl.create_agent_turn(**turn_request)
|
||||
]
|
||||
|
||||
assert len(turn_response) > 0
|
||||
check_event_types(turn_response)
|
||||
|
||||
|
|
|
@ -14,6 +14,9 @@ from pydantic import BaseModel
|
|||
from termcolor import colored
|
||||
|
||||
from llama_stack.distribution.datatypes import Provider
|
||||
from llama_stack.providers.datatypes import RemoteProviderConfig
|
||||
|
||||
from .env import get_env_or_fail
|
||||
|
||||
|
||||
class ProviderFixture(BaseModel):
|
||||
|
@ -21,6 +24,21 @@ class ProviderFixture(BaseModel):
|
|||
provider_data: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
def remote_stack_fixture() -> ProviderFixture:
|
||||
return ProviderFixture(
|
||||
providers=[
|
||||
Provider(
|
||||
provider_id="remote",
|
||||
provider_type="remote",
|
||||
config=RemoteProviderConfig(
|
||||
host=get_env_or_fail("REMOTE_STACK_HOST"),
|
||||
port=int(get_env_or_fail("REMOTE_STACK_PORT")),
|
||||
).model_dump(),
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def pytest_configure(config):
|
||||
config.option.tbstyle = "short"
|
||||
config.option.disable_warnings = True
|
||||
|
|
|
@ -18,7 +18,7 @@ from llama_stack.providers.impls.meta_reference.inference import (
|
|||
MetaReferenceInferenceConfig,
|
||||
)
|
||||
from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
|
||||
from ..conftest import ProviderFixture
|
||||
from ..conftest import ProviderFixture, remote_stack_fixture
|
||||
from ..env import get_env_or_fail
|
||||
|
||||
|
||||
|
@ -29,6 +29,11 @@ def inference_model(request):
|
|||
return request.config.getoption("--inference-model", None)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def inference_remote() -> ProviderFixture:
|
||||
return remote_stack_fixture()
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def inference_meta_reference(inference_model) -> ProviderFixture:
|
||||
inference_model = (
|
||||
|
@ -104,7 +109,7 @@ def inference_together() -> ProviderFixture:
|
|||
)
|
||||
|
||||
|
||||
INFERENCE_FIXTURES = ["meta_reference", "ollama", "fireworks", "together"]
|
||||
INFERENCE_FIXTURES = ["meta_reference", "ollama", "fireworks", "together", "remote"]
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="session")
|
||||
|
|
|
@ -15,10 +15,15 @@ from llama_stack.providers.adapters.memory.weaviate import WeaviateConfig
|
|||
from llama_stack.providers.impls.meta_reference.memory import FaissImplConfig
|
||||
|
||||
from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
|
||||
from ..conftest import ProviderFixture
|
||||
from ..conftest import ProviderFixture, remote_stack_fixture
|
||||
from ..env import get_env_or_fail
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def memory_remote() -> ProviderFixture:
|
||||
return remote_stack_fixture()
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def memory_meta_reference() -> ProviderFixture:
|
||||
return ProviderFixture(
|
||||
|
@ -68,7 +73,7 @@ def memory_weaviate() -> ProviderFixture:
|
|||
)
|
||||
|
||||
|
||||
MEMORY_FIXTURES = ["meta_reference", "pgvector", "weaviate"]
|
||||
MEMORY_FIXTURES = ["meta_reference", "pgvector", "weaviate", "remote"]
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="session")
|
||||
|
|
|
@ -6,6 +6,7 @@
|
|||
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
|
@ -16,6 +17,8 @@ from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
|
|||
from llama_stack.distribution.distribution import get_provider_registry
|
||||
from llama_stack.distribution.request_headers import set_request_provider_data
|
||||
from llama_stack.distribution.resolver import resolve_impls
|
||||
from llama_stack.distribution.store import CachedDiskDistributionRegistry
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl, SqliteKVStoreConfig
|
||||
|
||||
|
||||
async def resolve_impls_for_test_v2(
|
||||
|
@ -30,7 +33,11 @@ async def resolve_impls_for_test_v2(
|
|||
providers=providers,
|
||||
)
|
||||
run_config = parse_and_maybe_upgrade_config(run_config)
|
||||
impls = await resolve_impls(run_config, get_provider_registry())
|
||||
|
||||
sqlite_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db")
|
||||
dist_kvstore = await kvstore_impl(SqliteKVStoreConfig(db_path=sqlite_file.name))
|
||||
dist_registry = CachedDiskDistributionRegistry(dist_kvstore)
|
||||
impls = await resolve_impls(run_config, get_provider_registry(), dist_registry)
|
||||
|
||||
if provider_data:
|
||||
set_request_provider_data(
|
||||
|
|
|
@ -37,11 +37,19 @@ DEFAULT_PROVIDER_COMBINATIONS = [
|
|||
id="together",
|
||||
marks=pytest.mark.together,
|
||||
),
|
||||
pytest.param(
|
||||
{
|
||||
"inference": "remote",
|
||||
"safety": "remote",
|
||||
},
|
||||
id="remote",
|
||||
marks=pytest.mark.remote,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def pytest_configure(config):
|
||||
for mark in ["meta_reference", "ollama", "together"]:
|
||||
for mark in ["meta_reference", "ollama", "together", "remote"]:
|
||||
config.addinivalue_line(
|
||||
"markers",
|
||||
f"{mark}: marks tests as {mark} specific",
|
||||
|
|
|
@ -16,10 +16,15 @@ from llama_stack.providers.impls.meta_reference.safety import (
|
|||
|
||||
from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
|
||||
|
||||
from ..conftest import ProviderFixture
|
||||
from ..conftest import ProviderFixture, remote_stack_fixture
|
||||
from ..env import get_env_or_fail
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def safety_remote() -> ProviderFixture:
|
||||
return remote_stack_fixture()
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def safety_model(request):
|
||||
if hasattr(request, "param"):
|
||||
|
@ -60,7 +65,7 @@ def safety_together() -> ProviderFixture:
|
|||
)
|
||||
|
||||
|
||||
SAFETY_FIXTURES = ["meta_reference", "together"]
|
||||
SAFETY_FIXTURES = ["meta_reference", "together", "remote"]
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="session")
|
||||
|
|
|
@ -27,7 +27,7 @@ class TestSafety:
|
|||
|
||||
for shield in response:
|
||||
assert isinstance(shield, ShieldDefWithProvider)
|
||||
assert shield.type in [v.value for v in ShieldType]
|
||||
assert shield.shield_type in [v.value for v in ShieldType]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_shield(self, safety_stack):
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue