attempt to finish the implementation started by matt

This commit is contained in:
Ashwin Bharambe 2025-10-27 11:34:40 -07:00
parent 6b585fac00
commit fa4a9ece5b
4 changed files with 66 additions and 44 deletions

View file

@ -95,25 +95,34 @@ providers:
- provider_id: model-context-protocol
provider_type: remote::model-context-protocol
config: {}
metadata_store:
type: postgres
host: ${env.POSTGRES_HOST:=localhost}
port: ${env.POSTGRES_PORT:=5432}
db: ${env.POSTGRES_DB:=llamastack}
user: ${env.POSTGRES_USER:=llamastack}
password: ${env.POSTGRES_PASSWORD:=llamastack}
table_name: llamastack_kvstore
inference_store:
type: postgres
host: ${env.POSTGRES_HOST:=localhost}
port: ${env.POSTGRES_PORT:=5432}
db: ${env.POSTGRES_DB:=llamastack}
user: ${env.POSTGRES_USER:=llamastack}
password: ${env.POSTGRES_PASSWORD:=llamastack}
storage:
backends:
kv_default:
type: kv_postgres
host: ${env.POSTGRES_HOST:=localhost}
port: ${env.POSTGRES_PORT:=5432}
db: ${env.POSTGRES_DB:=llamastack}
user: ${env.POSTGRES_USER:=llamastack}
password: ${env.POSTGRES_PASSWORD:=llamastack}
table_name: ${env.POSTGRES_TABLE_NAME:=llamastack_kvstore}
sql_default:
type: sql_postgres
host: ${env.POSTGRES_HOST:=localhost}
port: ${env.POSTGRES_PORT:=5432}
db: ${env.POSTGRES_DB:=llamastack}
user: ${env.POSTGRES_USER:=llamastack}
password: ${env.POSTGRES_PASSWORD:=llamastack}
references:
metadata:
backend: kv_default
namespace: registry
inference:
backend: sql_default
table_name: inference_store
models:
- metadata:
embedding_dimension: 384
model_id: all-MiniLM-L6-v2
embedding_dimension: 768
model_id: nomic-embed-text-v1.5
provider_id: sentence-transformers
model_type: embedding
- model_id: ${env.INFERENCE_MODEL}

View file

@ -7,7 +7,7 @@
import os
from typing import Any
from pydantic import BaseModel, Field, SecretStr
from pydantic import BaseModel, Field, SecretStr, field_validator
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.schema_utils import json_schema_type
@ -48,7 +48,7 @@ class NVIDIAConfig(RemoteInferenceProviderConfig):
description="A base url for accessing the NVIDIA NIM",
)
api_key: SecretStr | None = Field(
default_factory=lambda: SecretStr(os.getenv("NVIDIA_API_KEY")),
default=None,
description="The NVIDIA API key, only needed of using the hosted service",
)
timeout: int = Field(
@ -60,6 +60,22 @@ class NVIDIAConfig(RemoteInferenceProviderConfig):
description="When set to false, the API version will not be appended to the base_url. By default, it is true.",
)
@field_validator("api_key", mode="before")
@classmethod
def _default_api_key_from_env(cls, value: SecretStr | str | None) -> SecretStr | None:
"""Populate the API key from the NVIDIA_API_KEY environment variable when absent."""
if value is None:
env_value = os.getenv("NVIDIA_API_KEY")
return SecretStr(env_value) if env_value else None
if isinstance(value, SecretStr):
return value
if isinstance(value, str):
return SecretStr(value)
return value
@classmethod
def sample_run_config(
cls,

View file

@ -23,7 +23,6 @@ class RunpodInferenceAdapter(OpenAIMixin):
"""
config: RunpodImplConfig
provider_data_api_key_field: str = "runpod_api_token"
def get_api_key(self) -> str:

View file

@ -7,7 +7,7 @@ required-version = ">=0.7.0"
[project]
name = "llama_stack"
version = "0.2.23"
version = "0.3.0"
authors = [{ name = "Meta Llama", email = "llama-oss@meta.com" }]
description = "Llama Stack"
readme = "README.md"
@ -25,17 +25,17 @@ classifiers = [
]
dependencies = [
"aiohttp",
"fastapi>=0.115.0,<1.0", # server
"fire", # for MCP in LLS client
"databricks-sdk",
"fastapi>=0.115.0,<1.0", # server
"fire", # for MCP in LLS client
"httpx",
"huggingface-hub>=0.34.0,<1.0",
"jinja2>=3.1.6",
"jsonschema",
"llama-stack-client>=0.2.23",
"openai>=1.107", # for expires_after support
"llama-stack-client>=0.3.0",
"openai>=1.107", # for expires_after support
"prompt-toolkit",
"python-dotenv",
"python-jose[cryptography]",
"pyjwt[crypto]>=2.10.0", # Pull crypto to support RS256 for jwt. Requires 2.10.0+ for ssl_context support.
"pydantic>=2.11.9",
"rich",
"starlette",
@ -43,20 +43,20 @@ dependencies = [
"tiktoken",
"pillow",
"h11>=0.16.0",
"python-multipart>=0.0.20", # For fastapi Form
"uvicorn>=0.34.0", # server
"opentelemetry-sdk>=1.30.0", # server
"python-multipart>=0.0.20", # For fastapi Form
"uvicorn>=0.34.0", # server
"opentelemetry-sdk>=1.30.0", # server
"opentelemetry-exporter-otlp-proto-http>=1.30.0", # server
"aiosqlite>=0.21.0", # server - for metadata store
"asyncpg", # for metadata store
"sqlalchemy[asyncio]>=2.0.41", # server - for conversations
"aiosqlite>=0.21.0", # server - for metadata store
"asyncpg", # for metadata store
"sqlalchemy[asyncio]>=2.0.41", # server - for conversations
]
[project.optional-dependencies]
ui = [
"streamlit",
"pandas",
"llama-stack-client>=0.2.23",
"llama-stack-client>=0.3.0",
"streamlit-option-menu",
]
@ -68,14 +68,14 @@ dev = [
"pytest-cov",
"pytest-html",
"pytest-json-report",
"pytest-socket", # For blocking network access in unit tests
"nbval", # For notebook testing
"pytest-socket", # For blocking network access in unit tests
"nbval", # For notebook testing
"black",
"ruff",
"types-requests",
"types-setuptools",
"pre-commit",
"ruamel.yaml", # needed for openapi generator
"ruamel.yaml", # needed for openapi generator
]
# These are the dependencies required for running unit tests.
unit = [
@ -122,6 +122,8 @@ test = [
"sqlalchemy",
"sqlalchemy[asyncio]>=2.0.41",
"requests",
"chromadb>=1.0.15",
"qdrant-client",
"pymilvus>=2.6.1",
"milvus-lite>=2.5.0",
"weaviate-client>=4.16.4",
@ -146,9 +148,7 @@ docs = [
"requests",
]
codegen = ["rich", "pydantic>=2.11.9", "jinja2>=3.1.6"]
benchmark = [
"locust>=2.39.1",
]
benchmark = ["locust>=2.39.1"]
[project.urls]
Homepage = "https://github.com/llamastack/llama-stack"
@ -247,7 +247,6 @@ follow_imports = "silent"
# to exclude the entire directory.
exclude = [
# As we fix more and more of these, we should remove them from the list
"^llama_stack/cli/download\\.py$",
"^llama_stack.core/build\\.py$",
"^llama_stack.core/client\\.py$",
"^llama_stack.core/request_headers\\.py$",
@ -337,6 +336,5 @@ classmethod-decorators = ["classmethod", "pydantic.field_validator"]
[tool.pytest.ini_options]
addopts = ["--durations=10"]
asyncio_mode = "auto"
markers = [
"allow_network: Allow network access for specific unit tests",
]
markers = ["allow_network: Allow network access for specific unit tests"]
filterwarnings = "ignore::DeprecationWarning"