mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-27 18:50:41 +00:00
feat: ability to execute external providers (#1672)
# What does this PR do? Providers that live outside of the llama-stack codebase are now supported. A new property `external_providers_dir` has been added to the main config and can be configured as follow: ``` external_providers_dir: /etc/llama-stack/providers.d/ ``` Where the expected structure is: ``` providers.d/ inference/ custom_ollama.yaml vllm.yaml vector_io/ qdrant.yaml ``` Where `custom_ollama.yaml` is: ``` adapter: adapter_type: custom_ollama pip_packages: ["ollama", "aiohttp"] config_class: llama_stack_ollama_provider.config.OllamaImplConfig module: llama_stack_ollama_provider api_dependencies: [] optional_api_dependencies: [] ``` Obviously the package must be installed on the system, here is the `llama_stack_ollama_provider` example: ``` $ uv pip show llama-stack-ollama-provider Using Python 3.10.16 environment at: /Users/leseb/Documents/AI/llama-stack/.venv Name: llama-stack-ollama-provider Version: 0.1.0 Location: /Users/leseb/Documents/AI/llama-stack/.venv/lib/python3.10/site-packages Editable project location: /private/var/folders/mq/rnm5w_7s2d3fxmtkx02knvhm0000gn/T/tmp.ZBHU5Ezxg4/ollama/llama-stack-ollama-provider Requires: Required-by: ``` Closes: https://github.com/meta-llama/llama-stack/issues/658 Signed-off-by: Sébastien Han <seb@redhat.com>
This commit is contained in:
parent
45e210fd0c
commit
389767010b
12 changed files with 875 additions and 7 deletions
|
@ -0,0 +1,3 @@
|
|||
# Ollama external provider for Llama Stack
|
||||
|
||||
Template code to create a new external provider for Llama Stack.
|
|
@ -0,0 +1,7 @@
|
|||
adapter:
|
||||
adapter_type: custom_ollama
|
||||
pip_packages: ["ollama", "aiohttp"]
|
||||
config_class: llama_stack_provider_ollama.config.OllamaImplConfig
|
||||
module: llama_stack_provider_ollama
|
||||
api_dependencies: []
|
||||
optional_api_dependencies: []
|
|
@ -0,0 +1,44 @@
|
|||
[project]
|
||||
dependencies = [
|
||||
"llama-stack",
|
||||
"pydantic",
|
||||
"ollama",
|
||||
"aiohttp",
|
||||
"aiosqlite",
|
||||
"autoevals",
|
||||
"blobfile",
|
||||
"chardet",
|
||||
"chromadb-client",
|
||||
"datasets",
|
||||
"faiss-cpu",
|
||||
"fastapi",
|
||||
"fire",
|
||||
"httpx",
|
||||
"matplotlib",
|
||||
"mcp",
|
||||
"nltk",
|
||||
"numpy",
|
||||
"openai",
|
||||
"opentelemetry-exporter-otlp-proto-http",
|
||||
"opentelemetry-sdk",
|
||||
"pandas",
|
||||
"pillow",
|
||||
"psycopg2-binary",
|
||||
"pymongo",
|
||||
"pypdf",
|
||||
"redis",
|
||||
"requests",
|
||||
"scikit-learn",
|
||||
"scipy",
|
||||
"sentencepiece",
|
||||
"tqdm",
|
||||
"transformers",
|
||||
"tree_sitter",
|
||||
"uvicorn",
|
||||
]
|
||||
|
||||
name = "llama-stack-provider-ollama"
|
||||
version = "0.1.0"
|
||||
description = "External provider for Ollama using the Llama Stack API"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10"
|
135
tests/external-provider/llama-stack-provider-ollama/run.yaml
Normal file
135
tests/external-provider/llama-stack-provider-ollama/run.yaml
Normal file
|
@ -0,0 +1,135 @@
|
|||
version: '2'
|
||||
image_name: ollama
|
||||
apis:
|
||||
- agents
|
||||
- datasetio
|
||||
- eval
|
||||
- inference
|
||||
- safety
|
||||
- scoring
|
||||
- telemetry
|
||||
- tool_runtime
|
||||
- vector_io
|
||||
providers:
|
||||
inference:
|
||||
- provider_id: custom_ollama
|
||||
provider_type: remote::custom_ollama
|
||||
config:
|
||||
url: ${env.OLLAMA_URL:http://localhost:11434}
|
||||
vector_io:
|
||||
- provider_id: faiss
|
||||
provider_type: inline::faiss
|
||||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/faiss_store.db
|
||||
safety:
|
||||
- provider_id: llama-guard
|
||||
provider_type: inline::llama-guard
|
||||
config:
|
||||
excluded_categories: []
|
||||
agents:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
persistence_store:
|
||||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/agents_store.db
|
||||
telemetry:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
service_name: ${env.OTEL_SERVICE_NAME:llama-stack}
|
||||
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
|
||||
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/ollama/trace_store.db}
|
||||
eval:
|
||||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/meta_reference_eval.db
|
||||
datasetio:
|
||||
- provider_id: huggingface
|
||||
provider_type: remote::huggingface
|
||||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/huggingface_datasetio.db
|
||||
- provider_id: localfs
|
||||
provider_type: inline::localfs
|
||||
config:
|
||||
kvstore:
|
||||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/localfs_datasetio.db
|
||||
scoring:
|
||||
- provider_id: basic
|
||||
provider_type: inline::basic
|
||||
config: {}
|
||||
- provider_id: llm-as-judge
|
||||
provider_type: inline::llm-as-judge
|
||||
config: {}
|
||||
- provider_id: braintrust
|
||||
provider_type: inline::braintrust
|
||||
config:
|
||||
openai_api_key: ${env.OPENAI_API_KEY:}
|
||||
tool_runtime:
|
||||
- provider_id: brave-search
|
||||
provider_type: remote::brave-search
|
||||
config:
|
||||
api_key: ${env.BRAVE_SEARCH_API_KEY:}
|
||||
max_results: 3
|
||||
- provider_id: tavily-search
|
||||
provider_type: remote::tavily-search
|
||||
config:
|
||||
api_key: ${env.TAVILY_SEARCH_API_KEY:}
|
||||
max_results: 3
|
||||
- provider_id: code-interpreter
|
||||
provider_type: inline::code-interpreter
|
||||
config: {}
|
||||
- provider_id: rag-runtime
|
||||
provider_type: inline::rag-runtime
|
||||
config: {}
|
||||
- provider_id: model-context-protocol
|
||||
provider_type: remote::model-context-protocol
|
||||
config: {}
|
||||
- provider_id: wolfram-alpha
|
||||
provider_type: remote::wolfram-alpha
|
||||
config:
|
||||
api_key: ${env.WOLFRAM_ALPHA_API_KEY:}
|
||||
metadata_store:
|
||||
type: sqlite
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/registry.db
|
||||
models:
|
||||
- metadata: {}
|
||||
model_id: ${env.INFERENCE_MODEL}
|
||||
provider_id: custom_ollama
|
||||
model_type: llm
|
||||
- metadata:
|
||||
embedding_dimension: 384
|
||||
model_id: all-MiniLM-L6-v2
|
||||
provider_id: custom_ollama
|
||||
provider_model_id: all-minilm:latest
|
||||
model_type: embedding
|
||||
shields: []
|
||||
vector_dbs: []
|
||||
datasets: []
|
||||
scoring_fns: []
|
||||
benchmarks: []
|
||||
tool_groups:
|
||||
- toolgroup_id: builtin::websearch
|
||||
provider_id: tavily-search
|
||||
- toolgroup_id: builtin::rag
|
||||
provider_id: rag-runtime
|
||||
- toolgroup_id: builtin::code_interpreter
|
||||
provider_id: code-interpreter
|
||||
- toolgroup_id: builtin::wolfram_alpha
|
||||
provider_id: wolfram-alpha
|
||||
server:
|
||||
port: 8321
|
||||
external_providers_dir: /tmp/providers.d
|
223
tests/unit/distribution/test_distribution.py
Normal file
223
tests/unit/distribution/test_distribution.py
Normal file
|
@ -0,0 +1,223 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
from pydantic import BaseModel, Field, ValidationError
|
||||
|
||||
from llama_stack.distribution.datatypes import Api, Provider, StackRunConfig
|
||||
from llama_stack.distribution.distribution import get_provider_registry
|
||||
from llama_stack.providers.datatypes import ProviderSpec
|
||||
|
||||
|
||||
class SampleConfig(BaseModel):
|
||||
foo: str = Field(
|
||||
default="bar",
|
||||
description="foo",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, **kwargs: Any) -> Dict[str, Any]:
|
||||
return {
|
||||
"foo": "baz",
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_providers():
|
||||
"""Mock the available_providers function to return test providers."""
|
||||
with patch("llama_stack.providers.registry.inference.available_providers") as mock:
|
||||
mock.return_value = [
|
||||
ProviderSpec(
|
||||
provider_type="test_provider",
|
||||
api=Api.inference,
|
||||
adapter_type="test_adapter",
|
||||
config_class="test_provider.config.TestProviderConfig",
|
||||
)
|
||||
]
|
||||
yield mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def base_config(tmp_path):
|
||||
"""Create a base StackRunConfig with common settings."""
|
||||
return StackRunConfig(
|
||||
image_name="test_image",
|
||||
providers={
|
||||
"inference": [
|
||||
Provider(
|
||||
provider_id="sample_provider",
|
||||
provider_type="sample",
|
||||
config=SampleConfig.sample_run_config(),
|
||||
)
|
||||
]
|
||||
},
|
||||
external_providers_dir=str(tmp_path),
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def provider_spec_yaml():
|
||||
"""Common provider spec YAML for testing."""
|
||||
return """
|
||||
adapter:
|
||||
adapter_type: test_provider
|
||||
config_class: test_provider.config.TestProviderConfig
|
||||
module: test_provider
|
||||
api_dependencies:
|
||||
- safety
|
||||
"""
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def inline_provider_spec_yaml():
|
||||
"""Common inline provider spec YAML for testing."""
|
||||
return """
|
||||
module: test_provider
|
||||
config_class: test_provider.config.TestProviderConfig
|
||||
pip_packages:
|
||||
- test-package
|
||||
api_dependencies:
|
||||
- safety
|
||||
optional_api_dependencies:
|
||||
- vector_io
|
||||
provider_data_validator: test_provider.validator.TestValidator
|
||||
container_image: test-image:latest
|
||||
"""
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def api_directories(tmp_path):
|
||||
"""Create the API directory structure for testing."""
|
||||
# Create remote provider directory
|
||||
remote_inference_dir = tmp_path / "remote" / "inference"
|
||||
remote_inference_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Create inline provider directory
|
||||
inline_inference_dir = tmp_path / "inline" / "inference"
|
||||
inline_inference_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
return remote_inference_dir, inline_inference_dir
|
||||
|
||||
|
||||
class TestProviderRegistry:
|
||||
"""Test suite for provider registry functionality."""
|
||||
|
||||
def test_builtin_providers(self, mock_providers):
|
||||
"""Test loading built-in providers."""
|
||||
registry = get_provider_registry(None)
|
||||
|
||||
assert Api.inference in registry
|
||||
assert "test_provider" in registry[Api.inference]
|
||||
assert registry[Api.inference]["test_provider"].provider_type == "test_provider"
|
||||
assert registry[Api.inference]["test_provider"].api == Api.inference
|
||||
|
||||
def test_external_remote_providers(self, api_directories, mock_providers, base_config, provider_spec_yaml):
|
||||
"""Test loading external remote providers from YAML files."""
|
||||
remote_dir, _ = api_directories
|
||||
with open(remote_dir / "test_provider.yaml", "w") as f:
|
||||
f.write(provider_spec_yaml)
|
||||
|
||||
registry = get_provider_registry(base_config)
|
||||
assert len(registry[Api.inference]) == 2
|
||||
|
||||
assert Api.inference in registry
|
||||
assert "remote::test_provider" in registry[Api.inference]
|
||||
provider = registry[Api.inference]["remote::test_provider"]
|
||||
assert provider.adapter.adapter_type == "test_provider"
|
||||
assert provider.adapter.module == "test_provider"
|
||||
assert provider.adapter.config_class == "test_provider.config.TestProviderConfig"
|
||||
assert Api.safety in provider.api_dependencies
|
||||
|
||||
def test_external_inline_providers(self, api_directories, mock_providers, base_config, inline_provider_spec_yaml):
|
||||
"""Test loading external inline providers from YAML files."""
|
||||
_, inline_dir = api_directories
|
||||
with open(inline_dir / "test_provider.yaml", "w") as f:
|
||||
f.write(inline_provider_spec_yaml)
|
||||
|
||||
registry = get_provider_registry(base_config)
|
||||
assert len(registry[Api.inference]) == 2
|
||||
|
||||
assert Api.inference in registry
|
||||
assert "inline::test_provider" in registry[Api.inference]
|
||||
provider = registry[Api.inference]["inline::test_provider"]
|
||||
assert provider.provider_type == "inline::test_provider"
|
||||
assert provider.module == "test_provider"
|
||||
assert provider.config_class == "test_provider.config.TestProviderConfig"
|
||||
assert provider.pip_packages == ["test-package"]
|
||||
assert Api.safety in provider.api_dependencies
|
||||
assert Api.vector_io in provider.optional_api_dependencies
|
||||
assert provider.provider_data_validator == "test_provider.validator.TestValidator"
|
||||
assert provider.container_image == "test-image:latest"
|
||||
|
||||
def test_invalid_yaml(self, api_directories, mock_providers, base_config):
|
||||
"""Test handling of invalid YAML files."""
|
||||
remote_dir, inline_dir = api_directories
|
||||
with open(remote_dir / "invalid.yaml", "w") as f:
|
||||
f.write("invalid: yaml: content: -")
|
||||
with open(inline_dir / "invalid.yaml", "w") as f:
|
||||
f.write("invalid: yaml: content: -")
|
||||
|
||||
with pytest.raises(yaml.YAMLError):
|
||||
get_provider_registry(base_config)
|
||||
|
||||
def test_missing_directory(self, mock_providers):
|
||||
"""Test handling of missing external providers directory."""
|
||||
config = StackRunConfig(
|
||||
image_name="test_image",
|
||||
providers={
|
||||
"inference": [
|
||||
Provider(
|
||||
provider_id="sample_provider",
|
||||
provider_type="sample",
|
||||
config=SampleConfig.sample_run_config(),
|
||||
)
|
||||
]
|
||||
},
|
||||
external_providers_dir="/nonexistent/dir",
|
||||
)
|
||||
with pytest.raises(FileNotFoundError):
|
||||
get_provider_registry(config)
|
||||
|
||||
def test_empty_api_directory(self, api_directories, mock_providers, base_config):
|
||||
"""Test handling of empty API directory."""
|
||||
registry = get_provider_registry(base_config)
|
||||
assert len(registry[Api.inference]) == 1 # Only built-in provider
|
||||
|
||||
def test_malformed_remote_provider_spec(self, api_directories, mock_providers, base_config):
|
||||
"""Test handling of malformed remote provider spec (missing required fields)."""
|
||||
remote_dir, _ = api_directories
|
||||
malformed_spec = """
|
||||
adapter:
|
||||
adapter_type: test_provider
|
||||
# Missing required fields
|
||||
api_dependencies:
|
||||
- safety
|
||||
"""
|
||||
with open(remote_dir / "malformed.yaml", "w") as f:
|
||||
f.write(malformed_spec)
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
get_provider_registry(base_config)
|
||||
|
||||
def test_malformed_inline_provider_spec(self, api_directories, mock_providers, base_config):
|
||||
"""Test handling of malformed inline provider spec (missing required fields)."""
|
||||
_, inline_dir = api_directories
|
||||
malformed_spec = """
|
||||
module: test_provider
|
||||
# Missing required config_class
|
||||
pip_packages:
|
||||
- test-package
|
||||
"""
|
||||
with open(inline_dir / "malformed.yaml", "w") as f:
|
||||
f.write(malformed_spec)
|
||||
|
||||
with pytest.raises(KeyError) as exc_info:
|
||||
get_provider_registry(base_config)
|
||||
assert "config_class" in str(exc_info.value)
|
Loading…
Add table
Add a link
Reference in a new issue