attempt to finish the implementation started by matt

This commit is contained in:
Ashwin Bharambe 2025-10-27 11:34:40 -07:00
parent 6b585fac00
commit fa4a9ece5b
4 changed files with 66 additions and 44 deletions

View file

@ -95,25 +95,34 @@ providers:
- provider_id: model-context-protocol - provider_id: model-context-protocol
provider_type: remote::model-context-protocol provider_type: remote::model-context-protocol
config: {} config: {}
metadata_store: storage:
type: postgres backends:
host: ${env.POSTGRES_HOST:=localhost} kv_default:
port: ${env.POSTGRES_PORT:=5432} type: kv_postgres
db: ${env.POSTGRES_DB:=llamastack} host: ${env.POSTGRES_HOST:=localhost}
user: ${env.POSTGRES_USER:=llamastack} port: ${env.POSTGRES_PORT:=5432}
password: ${env.POSTGRES_PASSWORD:=llamastack} db: ${env.POSTGRES_DB:=llamastack}
table_name: llamastack_kvstore user: ${env.POSTGRES_USER:=llamastack}
inference_store: password: ${env.POSTGRES_PASSWORD:=llamastack}
type: postgres table_name: ${env.POSTGRES_TABLE_NAME:=llamastack_kvstore}
host: ${env.POSTGRES_HOST:=localhost} sql_default:
port: ${env.POSTGRES_PORT:=5432} type: sql_postgres
db: ${env.POSTGRES_DB:=llamastack} host: ${env.POSTGRES_HOST:=localhost}
user: ${env.POSTGRES_USER:=llamastack} port: ${env.POSTGRES_PORT:=5432}
password: ${env.POSTGRES_PASSWORD:=llamastack} db: ${env.POSTGRES_DB:=llamastack}
user: ${env.POSTGRES_USER:=llamastack}
password: ${env.POSTGRES_PASSWORD:=llamastack}
references:
metadata:
backend: kv_default
namespace: registry
inference:
backend: sql_default
table_name: inference_store
models: models:
- metadata: - metadata:
embedding_dimension: 384 embedding_dimension: 768
model_id: all-MiniLM-L6-v2 model_id: nomic-embed-text-v1.5
provider_id: sentence-transformers provider_id: sentence-transformers
model_type: embedding model_type: embedding
- model_id: ${env.INFERENCE_MODEL} - model_id: ${env.INFERENCE_MODEL}

View file

@ -7,7 +7,7 @@
import os import os
from typing import Any from typing import Any
from pydantic import BaseModel, Field, SecretStr from pydantic import BaseModel, Field, SecretStr, field_validator
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.schema_utils import json_schema_type from llama_stack.schema_utils import json_schema_type
@ -48,7 +48,7 @@ class NVIDIAConfig(RemoteInferenceProviderConfig):
description="A base url for accessing the NVIDIA NIM", description="A base url for accessing the NVIDIA NIM",
) )
api_key: SecretStr | None = Field( api_key: SecretStr | None = Field(
default_factory=lambda: SecretStr(os.getenv("NVIDIA_API_KEY")), default=None,
description="The NVIDIA API key, only needed of using the hosted service", description="The NVIDIA API key, only needed of using the hosted service",
) )
timeout: int = Field( timeout: int = Field(
@ -60,6 +60,22 @@ class NVIDIAConfig(RemoteInferenceProviderConfig):
description="When set to false, the API version will not be appended to the base_url. By default, it is true.", description="When set to false, the API version will not be appended to the base_url. By default, it is true.",
) )
@field_validator("api_key", mode="before")
@classmethod
def _default_api_key_from_env(cls, value: SecretStr | str | None) -> SecretStr | None:
    """Normalize the raw ``api_key`` input before the field itself is validated.

    An existing ``SecretStr`` is returned unchanged and a plain string is
    wrapped in ``SecretStr``. When the input is ``None``, the
    ``NVIDIA_API_KEY`` environment variable is consulted instead; an unset
    or empty variable yields ``None``. Any other input is returned as-is.

    NOTE(review): Pydantic does not run field validators on default values
    unless ``validate_default=True`` is set for the field. Confirm that the
    ``api_key`` field opts in; otherwise simply omitting ``api_key`` will
    never reach this environment fallback.
    """
    # Explicitly supplied keys win over the environment.
    if isinstance(value, SecretStr):
        return value
    if isinstance(value, str):
        return SecretStr(value)
    if value is not None:
        # Not a recognised key type; hand it back unchanged.
        return value

    # Nothing supplied: fall back to the environment, treating "" as absent.
    fallback = os.getenv("NVIDIA_API_KEY")
    if not fallback:
        return None
    return SecretStr(fallback)
@classmethod @classmethod
def sample_run_config( def sample_run_config(
cls, cls,

View file

@ -23,7 +23,6 @@ class RunpodInferenceAdapter(OpenAIMixin):
""" """
config: RunpodImplConfig config: RunpodImplConfig
provider_data_api_key_field: str = "runpod_api_token" provider_data_api_key_field: str = "runpod_api_token"
def get_api_key(self) -> str: def get_api_key(self) -> str:

View file

@ -7,7 +7,7 @@ required-version = ">=0.7.0"
[project] [project]
name = "llama_stack" name = "llama_stack"
version = "0.2.23" version = "0.3.0"
authors = [{ name = "Meta Llama", email = "llama-oss@meta.com" }] authors = [{ name = "Meta Llama", email = "llama-oss@meta.com" }]
description = "Llama Stack" description = "Llama Stack"
readme = "README.md" readme = "README.md"
@ -25,17 +25,17 @@ classifiers = [
] ]
dependencies = [ dependencies = [
"aiohttp", "aiohttp",
"fastapi>=0.115.0,<1.0", # server "databricks-sdk",
"fire", # for MCP in LLS client "fastapi>=0.115.0,<1.0", # server
"fire", # for MCP in LLS client
"httpx", "httpx",
"huggingface-hub>=0.34.0,<1.0",
"jinja2>=3.1.6", "jinja2>=3.1.6",
"jsonschema", "jsonschema",
"llama-stack-client>=0.2.23", "llama-stack-client>=0.3.0",
"openai>=1.107", # for expires_after support "openai>=1.107", # for expires_after support
"prompt-toolkit", "prompt-toolkit",
"python-dotenv", "python-dotenv",
"python-jose[cryptography]", "pyjwt[crypto]>=2.10.0", # Pull crypto to support RS256 for jwt. Requires 2.10.0+ for ssl_context support.
"pydantic>=2.11.9", "pydantic>=2.11.9",
"rich", "rich",
"starlette", "starlette",
@ -43,20 +43,20 @@ dependencies = [
"tiktoken", "tiktoken",
"pillow", "pillow",
"h11>=0.16.0", "h11>=0.16.0",
"python-multipart>=0.0.20", # For fastapi Form "python-multipart>=0.0.20", # For fastapi Form
"uvicorn>=0.34.0", # server "uvicorn>=0.34.0", # server
"opentelemetry-sdk>=1.30.0", # server "opentelemetry-sdk>=1.30.0", # server
"opentelemetry-exporter-otlp-proto-http>=1.30.0", # server "opentelemetry-exporter-otlp-proto-http>=1.30.0", # server
"aiosqlite>=0.21.0", # server - for metadata store "aiosqlite>=0.21.0", # server - for metadata store
"asyncpg", # for metadata store "asyncpg", # for metadata store
"sqlalchemy[asyncio]>=2.0.41", # server - for conversations "sqlalchemy[asyncio]>=2.0.41", # server - for conversations
] ]
[project.optional-dependencies] [project.optional-dependencies]
ui = [ ui = [
"streamlit", "streamlit",
"pandas", "pandas",
"llama-stack-client>=0.2.23", "llama-stack-client>=0.3.0",
"streamlit-option-menu", "streamlit-option-menu",
] ]
@ -68,14 +68,14 @@ dev = [
"pytest-cov", "pytest-cov",
"pytest-html", "pytest-html",
"pytest-json-report", "pytest-json-report",
"pytest-socket", # For blocking network access in unit tests "pytest-socket", # For blocking network access in unit tests
"nbval", # For notebook testing "nbval", # For notebook testing
"black", "black",
"ruff", "ruff",
"types-requests", "types-requests",
"types-setuptools", "types-setuptools",
"pre-commit", "pre-commit",
"ruamel.yaml", # needed for openapi generator "ruamel.yaml", # needed for openapi generator
] ]
# These are the dependencies required for running unit tests. # These are the dependencies required for running unit tests.
unit = [ unit = [
@ -122,6 +122,8 @@ test = [
"sqlalchemy", "sqlalchemy",
"sqlalchemy[asyncio]>=2.0.41", "sqlalchemy[asyncio]>=2.0.41",
"requests", "requests",
"chromadb>=1.0.15",
"qdrant-client",
"pymilvus>=2.6.1", "pymilvus>=2.6.1",
"milvus-lite>=2.5.0", "milvus-lite>=2.5.0",
"weaviate-client>=4.16.4", "weaviate-client>=4.16.4",
@ -146,9 +148,7 @@ docs = [
"requests", "requests",
] ]
codegen = ["rich", "pydantic>=2.11.9", "jinja2>=3.1.6"] codegen = ["rich", "pydantic>=2.11.9", "jinja2>=3.1.6"]
benchmark = [ benchmark = ["locust>=2.39.1"]
"locust>=2.39.1",
]
[project.urls] [project.urls]
Homepage = "https://github.com/llamastack/llama-stack" Homepage = "https://github.com/llamastack/llama-stack"
@ -247,7 +247,6 @@ follow_imports = "silent"
# to exclude the entire directory. # to exclude the entire directory.
exclude = [ exclude = [
# As we fix more and more of these, we should remove them from the list # As we fix more and more of these, we should remove them from the list
"^llama_stack/cli/download\\.py$",
"^llama_stack.core/build\\.py$", "^llama_stack.core/build\\.py$",
"^llama_stack.core/client\\.py$", "^llama_stack.core/client\\.py$",
"^llama_stack.core/request_headers\\.py$", "^llama_stack.core/request_headers\\.py$",
@ -337,6 +336,5 @@ classmethod-decorators = ["classmethod", "pydantic.field_validator"]
[tool.pytest.ini_options] [tool.pytest.ini_options]
addopts = ["--durations=10"] addopts = ["--durations=10"]
asyncio_mode = "auto" asyncio_mode = "auto"
markers = [ markers = ["allow_network: Allow network access for specific unit tests"]
"allow_network: Allow network access for specific unit tests", filterwarnings = "ignore::DeprecationWarning"
]