mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-26 22:19:49 +00:00
feat(starter)!: simplify starter distro; litellm model registry changes (#2916)
This commit is contained in:
parent
3344d8a9e5
commit
9583f468f8
64 changed files with 2027 additions and 4092 deletions
12
.github/workflows/integration-tests.yml
vendored
12
.github/workflows/integration-tests.yml
vendored
|
@ -117,17 +117,13 @@ jobs:
|
|||
|
||||
EXCLUDE_TESTS="builtin_tool or safety_with_image or code_interpreter or test_rag"
|
||||
if [ "${{ matrix.provider }}" == "ollama" ]; then
|
||||
export ENABLE_OLLAMA="ollama"
|
||||
export OLLAMA_URL="http://0.0.0.0:11434"
|
||||
export OLLAMA_INFERENCE_MODEL="llama3.2:3b-instruct-fp16"
|
||||
export TEXT_MODEL=ollama/$OLLAMA_INFERENCE_MODEL
|
||||
export SAFETY_MODEL="llama-guard3:1b"
|
||||
EXTRA_PARAMS="--safety-shield=$SAFETY_MODEL"
|
||||
export TEXT_MODEL=ollama/llama3.2:3b-instruct-fp16
|
||||
export SAFETY_MODEL="ollama/llama-guard3:1b"
|
||||
EXTRA_PARAMS="--safety-shield=llama-guard"
|
||||
else
|
||||
export ENABLE_VLLM="vllm"
|
||||
export VLLM_URL="http://localhost:8000/v1"
|
||||
export VLLM_INFERENCE_MODEL="meta-llama/Llama-3.2-1B-Instruct"
|
||||
export TEXT_MODEL=vllm/$VLLM_INFERENCE_MODEL
|
||||
export TEXT_MODEL=vllm/meta-llama/Llama-3.2-1B-Instruct
|
||||
# TODO: remove the not(test_inference_store_tool_calls) once we can get the tool called consistently
|
||||
EXTRA_PARAMS=
|
||||
EXCLUDE_TESTS="${EXCLUDE_TESTS} or test_inference_store_tool_calls"
|
||||
|
|
|
@ -249,12 +249,6 @@
|
|||
],
|
||||
"source": [
|
||||
"from llama_stack_client import Agent, AgentEventLogger, RAGDocument, LlamaStackClient\n",
|
||||
"import os\n",
|
||||
"\n",
|
||||
"os.environ[\"ENABLE_OLLAMA\"] = \"ollama\"\n",
|
||||
"os.environ[\"OLLAMA_INFERENCE_MODEL\"] = \"llama3.2:3b\"\n",
|
||||
"os.environ[\"OLLAMA_EMBEDDING_MODEL\"] = \"all-minilm:l6-v2\"\n",
|
||||
"os.environ[\"OLLAMA_EMBEDDING_DIMENSION\"] = \"384\"\n",
|
||||
"\n",
|
||||
"vector_db_id = \"my_demo_vector_db\"\n",
|
||||
"client = LlamaStackClient(base_url=\"http://0.0.0.0:8321\")\n",
|
||||
|
|
|
@ -40,16 +40,16 @@ The following environment variables can be configured:
|
|||
|
||||
The following models are available by default:
|
||||
|
||||
- `meta/llama3-8b-instruct (aliases: meta-llama/Llama-3-8B-Instruct)`
|
||||
- `meta/llama3-70b-instruct (aliases: meta-llama/Llama-3-70B-Instruct)`
|
||||
- `meta/llama-3.1-8b-instruct (aliases: meta-llama/Llama-3.1-8B-Instruct)`
|
||||
- `meta/llama-3.1-70b-instruct (aliases: meta-llama/Llama-3.1-70B-Instruct)`
|
||||
- `meta/llama-3.1-405b-instruct (aliases: meta-llama/Llama-3.1-405B-Instruct-FP8)`
|
||||
- `meta/llama-3.2-1b-instruct (aliases: meta-llama/Llama-3.2-1B-Instruct)`
|
||||
- `meta/llama-3.2-3b-instruct (aliases: meta-llama/Llama-3.2-3B-Instruct)`
|
||||
- `meta/llama-3.2-11b-vision-instruct (aliases: meta-llama/Llama-3.2-11B-Vision-Instruct)`
|
||||
- `meta/llama-3.2-90b-vision-instruct (aliases: meta-llama/Llama-3.2-90B-Vision-Instruct)`
|
||||
- `meta/llama-3.3-70b-instruct (aliases: meta-llama/Llama-3.3-70B-Instruct)`
|
||||
- `meta/llama3-8b-instruct `
|
||||
- `meta/llama3-70b-instruct `
|
||||
- `meta/llama-3.1-8b-instruct `
|
||||
- `meta/llama-3.1-70b-instruct `
|
||||
- `meta/llama-3.1-405b-instruct `
|
||||
- `meta/llama-3.2-1b-instruct `
|
||||
- `meta/llama-3.2-3b-instruct `
|
||||
- `meta/llama-3.2-11b-vision-instruct `
|
||||
- `meta/llama-3.2-90b-vision-instruct `
|
||||
- `meta/llama-3.3-70b-instruct `
|
||||
- `nvidia/llama-3.2-nv-embedqa-1b-v2 `
|
||||
- `nvidia/nv-embedqa-e5-v5 `
|
||||
- `nvidia/nv-embedqa-mistral-7b-v2 `
|
||||
|
|
|
@ -158,7 +158,7 @@ export ENABLE_PGVECTOR=__disabled__
|
|||
The starter distribution uses several patterns for provider IDs:
|
||||
|
||||
1. **Direct provider IDs**: `faiss`, `ollama`, `vllm`
|
||||
2. **Environment-based provider IDs**: `${env.ENABLE_SQLITE_VEC+sqlite-vec}`
|
||||
2. **Environment-based provider IDs**: `${env.ENABLE_SQLITE_VEC:+sqlite-vec}`
|
||||
3. **Model-based provider IDs**: `${env.OLLAMA_INFERENCE_MODEL:__disabled__}`
|
||||
|
||||
When using the `+` pattern (like `${env.ENABLE_SQLITE_VEC+sqlite-vec}`), the provider is enabled by default and can be disabled by setting the environment variable to `__disabled__`.
|
||||
|
|
|
@ -59,7 +59,7 @@ Now let's build and run the Llama Stack config for Ollama.
|
|||
We use `starter` as template. By default all providers are disabled, this requires enable ollama by passing environment variables.
|
||||
|
||||
```bash
|
||||
ENABLE_OLLAMA=ollama OLLAMA_INFERENCE_MODEL="llama3.2:3b" llama stack build --template starter --image-type venv --run
|
||||
llama stack build --template starter --image-type venv --run
|
||||
```
|
||||
:::
|
||||
:::{tab-item} Using `conda`
|
||||
|
@ -70,7 +70,7 @@ which defines the providers and their settings.
|
|||
Now let's build and run the Llama Stack config for Ollama.
|
||||
|
||||
```bash
|
||||
ENABLE_OLLAMA=ollama INFERENCE_MODEL="llama3.2:3b" llama stack build --template starter --image-type conda --run
|
||||
llama stack build --template starter --image-type conda --run
|
||||
```
|
||||
:::
|
||||
:::{tab-item} Using a Container
|
||||
|
@ -80,8 +80,6 @@ component that works with different inference providers out of the box. For this
|
|||
configurations, please check out [this guide](../distributions/building_distro.md).
|
||||
First lets setup some environment variables and create a local directory to mount into the container’s file system.
|
||||
```bash
|
||||
export INFERENCE_MODEL="llama3.2:3b"
|
||||
export ENABLE_OLLAMA=ollama
|
||||
export LLAMA_STACK_PORT=8321
|
||||
mkdir -p ~/.llama
|
||||
```
|
||||
|
@ -94,7 +92,6 @@ docker run -it \
|
|||
-v ~/.llama:/root/.llama \
|
||||
llamastack/distribution-starter \
|
||||
--port $LLAMA_STACK_PORT \
|
||||
--env INFERENCE_MODEL=$INFERENCE_MODEL \
|
||||
--env OLLAMA_URL=http://host.docker.internal:11434
|
||||
```
|
||||
Note to start the container with Podman, you can do the same but replace `docker` at the start of the command with
|
||||
|
@ -116,7 +113,6 @@ docker run -it \
|
|||
--network=host \
|
||||
llamastack/distribution-starter \
|
||||
--port $LLAMA_STACK_PORT \
|
||||
--env INFERENCE_MODEL=$INFERENCE_MODEL \
|
||||
--env OLLAMA_URL=http://localhost:11434
|
||||
```
|
||||
:::
|
||||
|
|
|
@ -19,7 +19,7 @@ ollama run llama3.2:3b --keepalive 60m
|
|||
#### Step 2: Run the Llama Stack server
|
||||
We will use `uv` to run the Llama Stack server.
|
||||
```bash
|
||||
ENABLE_OLLAMA=ollama OLLAMA_INFERENCE_MODEL=llama3.2:3b uv run --with llama-stack llama stack build --template starter --image-type venv --run
|
||||
uv run --with llama-stack llama stack build --template starter --image-type venv --run
|
||||
```
|
||||
#### Step 3: Run the demo
|
||||
Now open up a new terminal and copy the following script into a file named `demo_script.py`.
|
||||
|
|
|
@ -13,7 +13,7 @@ Anthropic inference provider for accessing Claude models and Anthropic's AI serv
|
|||
## Sample Configuration
|
||||
|
||||
```yaml
|
||||
api_key: ${env.ANTHROPIC_API_KEY}
|
||||
api_key: ${env.ANTHROPIC_API_KEY:=}
|
||||
|
||||
```
|
||||
|
||||
|
|
|
@ -15,7 +15,7 @@ Cerebras inference provider for running models on Cerebras Cloud platform.
|
|||
|
||||
```yaml
|
||||
base_url: https://api.cerebras.ai
|
||||
api_key: ${env.CEREBRAS_API_KEY}
|
||||
api_key: ${env.CEREBRAS_API_KEY:=}
|
||||
|
||||
```
|
||||
|
||||
|
|
|
@ -14,8 +14,8 @@ Databricks inference provider for running models on Databricks' unified analytic
|
|||
## Sample Configuration
|
||||
|
||||
```yaml
|
||||
url: ${env.DATABRICKS_URL}
|
||||
api_token: ${env.DATABRICKS_API_TOKEN}
|
||||
url: ${env.DATABRICKS_URL:=}
|
||||
api_token: ${env.DATABRICKS_API_TOKEN:=}
|
||||
|
||||
```
|
||||
|
||||
|
|
|
@ -16,7 +16,7 @@ Fireworks AI inference provider for Llama models and other AI models on the Fire
|
|||
|
||||
```yaml
|
||||
url: https://api.fireworks.ai/inference/v1
|
||||
api_key: ${env.FIREWORKS_API_KEY}
|
||||
api_key: ${env.FIREWORKS_API_KEY:=}
|
||||
|
||||
```
|
||||
|
||||
|
|
|
@ -13,7 +13,7 @@ Google Gemini inference provider for accessing Gemini models and Google's AI ser
|
|||
## Sample Configuration
|
||||
|
||||
```yaml
|
||||
api_key: ${env.GEMINI_API_KEY}
|
||||
api_key: ${env.GEMINI_API_KEY:=}
|
||||
|
||||
```
|
||||
|
||||
|
|
|
@ -15,7 +15,7 @@ Groq inference provider for ultra-fast inference using Groq's LPU technology.
|
|||
|
||||
```yaml
|
||||
url: https://api.groq.com
|
||||
api_key: ${env.GROQ_API_KEY}
|
||||
api_key: ${env.GROQ_API_KEY:=}
|
||||
|
||||
```
|
||||
|
||||
|
|
|
@ -13,7 +13,7 @@ OpenAI inference provider for accessing GPT models and other OpenAI services.
|
|||
## Sample Configuration
|
||||
|
||||
```yaml
|
||||
api_key: ${env.OPENAI_API_KEY}
|
||||
api_key: ${env.OPENAI_API_KEY:=}
|
||||
|
||||
```
|
||||
|
||||
|
|
|
@ -15,7 +15,7 @@ SambaNova OpenAI-compatible provider for using SambaNova models with OpenAI API
|
|||
|
||||
```yaml
|
||||
openai_compat_api_base: https://api.sambanova.ai/v1
|
||||
api_key: ${env.SAMBANOVA_API_KEY}
|
||||
api_key: ${env.SAMBANOVA_API_KEY:=}
|
||||
|
||||
```
|
||||
|
||||
|
|
|
@ -15,7 +15,7 @@ SambaNova inference provider for running models on SambaNova's dataflow architec
|
|||
|
||||
```yaml
|
||||
url: https://api.sambanova.ai/v1
|
||||
api_key: ${env.SAMBANOVA_API_KEY}
|
||||
api_key: ${env.SAMBANOVA_API_KEY:=}
|
||||
|
||||
```
|
||||
|
||||
|
|
|
@ -13,7 +13,7 @@ Text Generation Inference (TGI) provider for HuggingFace model serving.
|
|||
## Sample Configuration
|
||||
|
||||
```yaml
|
||||
url: ${env.TGI_URL}
|
||||
url: ${env.TGI_URL:=}
|
||||
|
||||
```
|
||||
|
||||
|
|
|
@ -16,7 +16,7 @@ Together AI inference provider for open-source models and collaborative AI devel
|
|||
|
||||
```yaml
|
||||
url: https://api.together.xyz/v1
|
||||
api_key: ${env.TOGETHER_API_KEY}
|
||||
api_key: ${env.TOGETHER_API_KEY:=}
|
||||
|
||||
```
|
||||
|
||||
|
|
|
@ -15,7 +15,7 @@ SambaNova's safety provider for content moderation and safety filtering.
|
|||
|
||||
```yaml
|
||||
url: https://api.sambanova.ai/v1
|
||||
api_key: ${env.SAMBANOVA_API_KEY}
|
||||
api_key: ${env.SAMBANOVA_API_KEY:=}
|
||||
|
||||
```
|
||||
|
||||
|
|
|
@ -25,7 +25,8 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
|||
async def refresh(self) -> None:
|
||||
for provider_id, provider in self.impls_by_provider_id.items():
|
||||
refresh = await provider.should_refresh_models()
|
||||
if not (refresh or provider_id in self.listed_providers):
|
||||
refresh = refresh or provider_id not in self.listed_providers
|
||||
if not refresh:
|
||||
continue
|
||||
|
||||
try:
|
||||
|
@ -138,6 +139,9 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
|||
# avoid overwriting a non-provider-registered model entry
|
||||
continue
|
||||
|
||||
if model.identifier == model.provider_resource_id:
|
||||
model.identifier = f"{provider_id}/{model.provider_resource_id}"
|
||||
|
||||
logger.debug(f"registering model {model.identifier} ({model.provider_resource_id})")
|
||||
await self.register_object(
|
||||
ModelWithOwner(
|
||||
|
|
|
@ -611,11 +611,8 @@ def extract_path_params(route: str) -> list[str]:
|
|||
|
||||
def remove_disabled_providers(obj):
|
||||
if isinstance(obj, dict):
|
||||
if (
|
||||
obj.get("provider_id") == "__disabled__"
|
||||
or obj.get("shield_id") == "__disabled__"
|
||||
or obj.get("provider_model_id") == "__disabled__"
|
||||
):
|
||||
keys = ["provider_id", "shield_id", "provider_model_id", "model_id"]
|
||||
if any(k in obj and obj[k] in ("__disabled__", "", None) for k in keys):
|
||||
return None
|
||||
return {k: v for k, v in ((k, remove_disabled_providers(v)) for k, v in obj.items()) if v is not None}
|
||||
elif isinstance(obj, list):
|
||||
|
|
|
@ -105,23 +105,10 @@ async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]):
|
|||
method = getattr(impls[api], register_method)
|
||||
for obj in objects:
|
||||
logger.debug(f"registering {rsrc.capitalize()} {obj} for provider {obj.provider_id}")
|
||||
# Do not register models on disabled providers
|
||||
if hasattr(obj, "provider_id") and obj.provider_id is not None and obj.provider_id == "__disabled__":
|
||||
logger.debug(f"Skipping {rsrc.capitalize()} registration for disabled provider.")
|
||||
continue
|
||||
# In complex templates, like our starter template, we may have dynamic model ids
|
||||
# given by environment variables. This allows those environment variables to have
|
||||
# a default value of __disabled__ to skip registration of the model if not set.
|
||||
if (
|
||||
hasattr(obj, "provider_model_id")
|
||||
and obj.provider_model_id is not None
|
||||
and "__disabled__" in obj.provider_model_id
|
||||
):
|
||||
logger.debug(f"Skipping {rsrc.capitalize()} registration for disabled model.")
|
||||
continue
|
||||
|
||||
if hasattr(obj, "shield_id") and obj.shield_id is not None and obj.shield_id == "__disabled__":
|
||||
logger.debug(f"Skipping {rsrc.capitalize()} registration for disabled shield.")
|
||||
# Do not register models on disabled providers
|
||||
if hasattr(obj, "provider_id") and (not obj.provider_id or obj.provider_id == "__disabled__"):
|
||||
logger.debug(f"Skipping {rsrc.capitalize()} registration for disabled provider.")
|
||||
continue
|
||||
|
||||
# we want to maintain the type information in arguments to method.
|
||||
|
@ -331,8 +318,10 @@ async def construct_stack(
|
|||
|
||||
await register_resources(run_config, impls)
|
||||
|
||||
await refresh_registry_once(impls)
|
||||
|
||||
global REGISTRY_REFRESH_TASK
|
||||
REGISTRY_REFRESH_TASK = asyncio.create_task(refresh_registry(impls))
|
||||
REGISTRY_REFRESH_TASK = asyncio.create_task(refresh_registry_task(impls))
|
||||
|
||||
def cb(task):
|
||||
import traceback
|
||||
|
@ -368,11 +357,17 @@ async def shutdown_stack(impls: dict[Api, Any]):
|
|||
REGISTRY_REFRESH_TASK.cancel()
|
||||
|
||||
|
||||
async def refresh_registry(impls: dict[Api, Any]):
|
||||
async def refresh_registry_once(impls: dict[Api, Any]):
|
||||
logger.info("refreshing registry")
|
||||
routing_tables = [v for v in impls.values() if isinstance(v, CommonRoutingTableImpl)]
|
||||
for routing_table in routing_tables:
|
||||
await routing_table.refresh()
|
||||
|
||||
|
||||
async def refresh_registry_task(impls: dict[Api, Any]):
|
||||
logger.info("starting registry refresh task")
|
||||
while True:
|
||||
for routing_table in routing_tables:
|
||||
await routing_table.refresh()
|
||||
await refresh_registry_once(impls)
|
||||
|
||||
await asyncio.sleep(REGISTRY_REFRESH_INTERVAL_SECONDS)
|
||||
|
||||
|
|
|
@ -43,6 +43,9 @@ class ModelsProtocolPrivate(Protocol):
|
|||
-> Provider uses provider-model-id for inference
|
||||
"""
|
||||
|
||||
# this should be called `on_model_register` or something like that.
|
||||
# the provider should _not_ be able to change the object in this
|
||||
# callback
|
||||
async def register_model(self, model: Model) -> Model: ...
|
||||
|
||||
async def unregister_model(self, model_id: str) -> None: ...
|
||||
|
|
|
@ -146,9 +146,9 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
|
|||
pass
|
||||
|
||||
async def register_shield(self, shield: Shield) -> None:
|
||||
# Allow any model to be registered as a shield
|
||||
# The model will be validated during runtime when making inference calls
|
||||
pass
|
||||
model_id = shield.provider_resource_id
|
||||
if not model_id:
|
||||
raise ValueError("Llama Guard shield must have a model id")
|
||||
|
||||
async def run_shield(
|
||||
self,
|
||||
|
|
|
@ -15,6 +15,7 @@ class AnthropicInferenceAdapter(LiteLLMOpenAIMixin):
|
|||
LiteLLMOpenAIMixin.__init__(
|
||||
self,
|
||||
MODEL_ENTRIES,
|
||||
litellm_provider_name="anthropic",
|
||||
api_key_from_config=config.api_key,
|
||||
provider_data_api_key_field="anthropic_api_key",
|
||||
)
|
||||
|
|
|
@ -26,7 +26,7 @@ class AnthropicConfig(BaseModel):
|
|||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, api_key: str = "${env.ANTHROPIC_API_KEY}", **kwargs) -> dict[str, Any]:
|
||||
def sample_run_config(cls, api_key: str = "${env.ANTHROPIC_API_KEY:=}", **kwargs) -> dict[str, Any]:
|
||||
return {
|
||||
"api_key": api_key,
|
||||
}
|
||||
|
|
|
@ -10,9 +10,9 @@ from llama_stack.providers.utils.inference.model_registry import (
|
|||
)
|
||||
|
||||
LLM_MODEL_IDS = [
|
||||
"anthropic/claude-3-5-sonnet-latest",
|
||||
"anthropic/claude-3-7-sonnet-latest",
|
||||
"anthropic/claude-3-5-haiku-latest",
|
||||
"claude-3-5-sonnet-latest",
|
||||
"claude-3-7-sonnet-latest",
|
||||
"claude-3-5-haiku-latest",
|
||||
]
|
||||
|
||||
SAFETY_MODELS_ENTRIES = []
|
||||
|
@ -21,17 +21,17 @@ MODEL_ENTRIES = (
|
|||
[ProviderModelEntry(provider_model_id=m) for m in LLM_MODEL_IDS]
|
||||
+ [
|
||||
ProviderModelEntry(
|
||||
provider_model_id="anthropic/voyage-3",
|
||||
provider_model_id="voyage-3",
|
||||
model_type=ModelType.embedding,
|
||||
metadata={"embedding_dimension": 1024, "context_length": 32000},
|
||||
),
|
||||
ProviderModelEntry(
|
||||
provider_model_id="anthropic/voyage-3-lite",
|
||||
provider_model_id="voyage-3-lite",
|
||||
model_type=ModelType.embedding,
|
||||
metadata={"embedding_dimension": 512, "context_length": 32000},
|
||||
),
|
||||
ProviderModelEntry(
|
||||
provider_model_id="anthropic/voyage-code-3",
|
||||
provider_model_id="voyage-code-3",
|
||||
model_type=ModelType.embedding,
|
||||
metadata={"embedding_dimension": 1024, "context_length": 32000},
|
||||
),
|
||||
|
|
|
@ -63,18 +63,20 @@ class BedrockInferenceAdapter(
|
|||
def __init__(self, config: BedrockConfig) -> None:
|
||||
ModelRegistryHelper.__init__(self, MODEL_ENTRIES)
|
||||
self._config = config
|
||||
|
||||
self._client = create_bedrock_client(config)
|
||||
self._client = None
|
||||
|
||||
@property
|
||||
def client(self) -> BaseClient:
|
||||
if self._client is None:
|
||||
self._client = create_bedrock_client(self._config)
|
||||
return self._client
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
self.client.close()
|
||||
if self._client is not None:
|
||||
self._client.close()
|
||||
|
||||
async def completion(
|
||||
self,
|
||||
|
|
|
@ -65,6 +65,7 @@ class CerebrasInferenceAdapter(
|
|||
)
|
||||
self.config = config
|
||||
|
||||
# TODO: make this use provider data, etc. like other providers
|
||||
self.client = AsyncCerebras(
|
||||
base_url=self.config.base_url,
|
||||
api_key=self.config.api_key.get_secret_value(),
|
||||
|
|
|
@ -26,7 +26,7 @@ class CerebrasImplConfig(BaseModel):
|
|||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, api_key: str = "${env.CEREBRAS_API_KEY}", **kwargs) -> dict[str, Any]:
|
||||
def sample_run_config(cls, api_key: str = "${env.CEREBRAS_API_KEY:=}", **kwargs) -> dict[str, Any]:
|
||||
return {
|
||||
"base_url": DEFAULT_BASE_URL,
|
||||
"api_key": api_key,
|
||||
|
|
|
@ -25,8 +25,8 @@ class DatabricksImplConfig(BaseModel):
|
|||
@classmethod
|
||||
def sample_run_config(
|
||||
cls,
|
||||
url: str = "${env.DATABRICKS_URL}",
|
||||
api_token: str = "${env.DATABRICKS_API_TOKEN}",
|
||||
url: str = "${env.DATABRICKS_URL:=}",
|
||||
api_token: str = "${env.DATABRICKS_API_TOKEN:=}",
|
||||
**kwargs: Any,
|
||||
) -> dict[str, Any]:
|
||||
return {
|
||||
|
|
|
@ -24,7 +24,7 @@ class FireworksImplConfig(RemoteInferenceProviderConfig):
|
|||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, api_key: str = "${env.FIREWORKS_API_KEY}", **kwargs) -> dict[str, Any]:
|
||||
def sample_run_config(cls, api_key: str = "${env.FIREWORKS_API_KEY:=}", **kwargs) -> dict[str, Any]:
|
||||
return {
|
||||
"url": "https://api.fireworks.ai/inference/v1",
|
||||
"api_key": api_key,
|
||||
|
|
|
@ -26,7 +26,7 @@ class GeminiConfig(BaseModel):
|
|||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, api_key: str = "${env.GEMINI_API_KEY}", **kwargs) -> dict[str, Any]:
|
||||
def sample_run_config(cls, api_key: str = "${env.GEMINI_API_KEY:=}", **kwargs) -> dict[str, Any]:
|
||||
return {
|
||||
"api_key": api_key,
|
||||
}
|
||||
|
|
|
@ -15,6 +15,7 @@ class GeminiInferenceAdapter(LiteLLMOpenAIMixin):
|
|||
LiteLLMOpenAIMixin.__init__(
|
||||
self,
|
||||
MODEL_ENTRIES,
|
||||
litellm_provider_name="gemini",
|
||||
api_key_from_config=config.api_key,
|
||||
provider_data_api_key_field="gemini_api_key",
|
||||
)
|
||||
|
|
|
@ -10,11 +10,11 @@ from llama_stack.providers.utils.inference.model_registry import (
|
|||
)
|
||||
|
||||
LLM_MODEL_IDS = [
|
||||
"gemini/gemini-1.5-flash",
|
||||
"gemini/gemini-1.5-pro",
|
||||
"gemini/gemini-2.0-flash",
|
||||
"gemini/gemini-2.5-flash",
|
||||
"gemini/gemini-2.5-pro",
|
||||
"gemini-1.5-flash",
|
||||
"gemini-1.5-pro",
|
||||
"gemini-2.0-flash",
|
||||
"gemini-2.5-flash",
|
||||
"gemini-2.5-pro",
|
||||
]
|
||||
|
||||
SAFETY_MODELS_ENTRIES = []
|
||||
|
@ -23,7 +23,7 @@ MODEL_ENTRIES = (
|
|||
[ProviderModelEntry(provider_model_id=m) for m in LLM_MODEL_IDS]
|
||||
+ [
|
||||
ProviderModelEntry(
|
||||
provider_model_id="gemini/text-embedding-004",
|
||||
provider_model_id="text-embedding-004",
|
||||
model_type=ModelType.embedding,
|
||||
metadata={"embedding_dimension": 768, "context_length": 2048},
|
||||
),
|
||||
|
|
|
@ -32,7 +32,7 @@ class GroqConfig(BaseModel):
|
|||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, api_key: str = "${env.GROQ_API_KEY}", **kwargs) -> dict[str, Any]:
|
||||
def sample_run_config(cls, api_key: str = "${env.GROQ_API_KEY:=}", **kwargs) -> dict[str, Any]:
|
||||
return {
|
||||
"url": "https://api.groq.com",
|
||||
"api_key": api_key,
|
||||
|
|
|
@ -34,6 +34,7 @@ class GroqInferenceAdapter(LiteLLMOpenAIMixin):
|
|||
LiteLLMOpenAIMixin.__init__(
|
||||
self,
|
||||
model_entries=MODEL_ENTRIES,
|
||||
litellm_provider_name="groq",
|
||||
api_key_from_config=config.api_key,
|
||||
provider_data_api_key_field="groq_api_key",
|
||||
)
|
||||
|
@ -96,7 +97,7 @@ class GroqInferenceAdapter(LiteLLMOpenAIMixin):
|
|||
tool_choice = "required"
|
||||
|
||||
params = await prepare_openai_completion_params(
|
||||
model=model_obj.provider_resource_id.replace("groq/", ""),
|
||||
model=model_obj.provider_resource_id,
|
||||
messages=messages,
|
||||
frequency_penalty=frequency_penalty,
|
||||
function_call=function_call,
|
||||
|
|
|
@ -14,19 +14,19 @@ SAFETY_MODELS_ENTRIES = []
|
|||
|
||||
MODEL_ENTRIES = [
|
||||
build_hf_repo_model_entry(
|
||||
"groq/llama3-8b-8192",
|
||||
"llama3-8b-8192",
|
||||
CoreModelId.llama3_1_8b_instruct.value,
|
||||
),
|
||||
build_model_entry(
|
||||
"groq/llama-3.1-8b-instant",
|
||||
"llama-3.1-8b-instant",
|
||||
CoreModelId.llama3_1_8b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"groq/llama3-70b-8192",
|
||||
"llama3-70b-8192",
|
||||
CoreModelId.llama3_70b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"groq/llama-3.3-70b-versatile",
|
||||
"llama-3.3-70b-versatile",
|
||||
CoreModelId.llama3_3_70b_instruct.value,
|
||||
),
|
||||
# Groq only contains a preview version for llama-3.2-3b
|
||||
|
@ -34,23 +34,15 @@ MODEL_ENTRIES = [
|
|||
# to pass the test fixture
|
||||
# TODO(aidand): Replace this with a stable model once Groq supports it
|
||||
build_hf_repo_model_entry(
|
||||
"groq/llama-3.2-3b-preview",
|
||||
"llama-3.2-3b-preview",
|
||||
CoreModelId.llama3_2_3b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"groq/llama-4-scout-17b-16e-instruct",
|
||||
"meta-llama/llama-4-scout-17b-16e-instruct",
|
||||
CoreModelId.llama4_scout_17b_16e_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"groq/meta-llama/llama-4-scout-17b-16e-instruct",
|
||||
CoreModelId.llama4_scout_17b_16e_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"groq/llama-4-maverick-17b-128e-instruct",
|
||||
CoreModelId.llama4_maverick_17b_128e_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"groq/meta-llama/llama-4-maverick-17b-128e-instruct",
|
||||
"meta-llama/llama-4-maverick-17b-128e-instruct",
|
||||
CoreModelId.llama4_maverick_17b_128e_instruct.value,
|
||||
),
|
||||
] + SAFETY_MODELS_ENTRIES
|
||||
|
|
|
@ -32,6 +32,7 @@ class LlamaCompatInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
|
|||
LiteLLMOpenAIMixin.__init__(
|
||||
self,
|
||||
model_entries=MODEL_ENTRIES,
|
||||
litellm_provider_name="llama",
|
||||
api_key_from_config=config.api_key,
|
||||
provider_data_api_key_field="llama_api_key",
|
||||
openai_compat_api_base=config.openai_compat_api_base,
|
||||
|
|
|
@ -166,7 +166,7 @@ class OllamaInferenceAdapter(
|
|||
]
|
||||
for m in response.models:
|
||||
# kill embedding models since we don't know dimensions for them
|
||||
if m.details.family in ["bert"]:
|
||||
if "bert" in m.details.family:
|
||||
continue
|
||||
models.append(
|
||||
Model(
|
||||
|
@ -420,9 +420,6 @@ class OllamaInferenceAdapter(
|
|||
except ValueError:
|
||||
pass # Ignore statically unknown model, will check live listing
|
||||
|
||||
if model.provider_resource_id is None:
|
||||
raise ValueError("Model provider_resource_id cannot be None")
|
||||
|
||||
if model.model_type == ModelType.embedding:
|
||||
response = await self.client.list()
|
||||
if model.provider_resource_id not in [m.model for m in response.models]:
|
||||
|
@ -433,9 +430,9 @@ class OllamaInferenceAdapter(
|
|||
# - models not currently running are run by the ollama server as needed
|
||||
response = await self.client.list()
|
||||
available_models = [m.model for m in response.models]
|
||||
provider_resource_id = self.register_helper.get_provider_model_id(model.provider_resource_id)
|
||||
if provider_resource_id is None:
|
||||
provider_resource_id = model.provider_resource_id
|
||||
|
||||
provider_resource_id = model.provider_resource_id
|
||||
assert provider_resource_id is not None # mypy
|
||||
if provider_resource_id not in available_models:
|
||||
available_models_latest = [m.model.split(":latest")[0] for m in response.models]
|
||||
if provider_resource_id in available_models_latest:
|
||||
|
@ -443,7 +440,9 @@ class OllamaInferenceAdapter(
|
|||
f"Imprecise provider resource id was used but 'latest' is available in Ollama - using '{model.provider_resource_id}:latest'"
|
||||
)
|
||||
return model
|
||||
raise UnsupportedModelError(model.provider_resource_id, available_models)
|
||||
raise UnsupportedModelError(provider_resource_id, available_models)
|
||||
|
||||
# mutating this should be considered an anti-pattern
|
||||
model.provider_resource_id = provider_resource_id
|
||||
|
||||
return model
|
||||
|
|
|
@ -26,7 +26,7 @@ class OpenAIConfig(BaseModel):
|
|||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, api_key: str = "${env.OPENAI_API_KEY}", **kwargs) -> dict[str, Any]:
|
||||
def sample_run_config(cls, api_key: str = "${env.OPENAI_API_KEY:=}", **kwargs) -> dict[str, Any]:
|
||||
return {
|
||||
"api_key": api_key,
|
||||
}
|
||||
|
|
|
@ -45,6 +45,7 @@ class OpenAIInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
|
|||
LiteLLMOpenAIMixin.__init__(
|
||||
self,
|
||||
MODEL_ENTRIES,
|
||||
litellm_provider_name="openai",
|
||||
api_key_from_config=config.api_key,
|
||||
provider_data_api_key_field="openai_api_key",
|
||||
)
|
||||
|
|
|
@ -30,7 +30,7 @@ class SambaNovaImplConfig(BaseModel):
|
|||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, api_key: str = "${env.SAMBANOVA_API_KEY}", **kwargs) -> dict[str, Any]:
|
||||
def sample_run_config(cls, api_key: str = "${env.SAMBANOVA_API_KEY:=}", **kwargs) -> dict[str, Any]:
|
||||
return {
|
||||
"url": "https://api.sambanova.ai/v1",
|
||||
"api_key": api_key,
|
||||
|
|
|
@ -9,49 +9,20 @@ from llama_stack.providers.utils.inference.model_registry import (
|
|||
build_hf_repo_model_entry,
|
||||
)
|
||||
|
||||
SAFETY_MODELS_ENTRIES = [
|
||||
build_hf_repo_model_entry(
|
||||
"sambanova/Meta-Llama-Guard-3-8B",
|
||||
CoreModelId.llama_guard_3_8b.value,
|
||||
),
|
||||
]
|
||||
SAFETY_MODELS_ENTRIES = []
|
||||
|
||||
|
||||
MODEL_ENTRIES = [
|
||||
build_hf_repo_model_entry(
|
||||
"sambanova/Meta-Llama-3.1-8B-Instruct",
|
||||
"Meta-Llama-3.1-8B-Instruct",
|
||||
CoreModelId.llama3_1_8b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"sambanova/Meta-Llama-3.1-405B-Instruct",
|
||||
CoreModelId.llama3_1_405b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"sambanova/Meta-Llama-3.2-1B-Instruct",
|
||||
CoreModelId.llama3_2_1b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"sambanova/Meta-Llama-3.2-3B-Instruct",
|
||||
CoreModelId.llama3_2_3b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"sambanova/Meta-Llama-3.3-70B-Instruct",
|
||||
"Meta-Llama-3.3-70B-Instruct",
|
||||
CoreModelId.llama3_3_70b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"sambanova/Llama-3.2-11B-Vision-Instruct",
|
||||
CoreModelId.llama3_2_11b_vision_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"sambanova/Llama-3.2-90B-Vision-Instruct",
|
||||
CoreModelId.llama3_2_90b_vision_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"sambanova/Llama-4-Scout-17B-16E-Instruct",
|
||||
CoreModelId.llama4_scout_17b_16e_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"sambanova/Llama-4-Maverick-17B-128E-Instruct",
|
||||
"Llama-4-Maverick-17B-128E-Instruct",
|
||||
CoreModelId.llama4_maverick_17b_128e_instruct.value,
|
||||
),
|
||||
] + SAFETY_MODELS_ENTRIES
|
||||
|
|
|
@ -182,6 +182,7 @@ class SambaNovaInferenceAdapter(LiteLLMOpenAIMixin):
|
|||
LiteLLMOpenAIMixin.__init__(
|
||||
self,
|
||||
model_entries=MODEL_ENTRIES,
|
||||
litellm_provider_name="sambanova",
|
||||
api_key_from_config=self.config.api_key.get_secret_value() if self.config.api_key else None,
|
||||
provider_data_api_key_field="sambanova_api_key",
|
||||
)
|
||||
|
|
|
@ -19,7 +19,7 @@ class TGIImplConfig(BaseModel):
|
|||
@classmethod
|
||||
def sample_run_config(
|
||||
cls,
|
||||
url: str = "${env.TGI_URL}",
|
||||
url: str = "${env.TGI_URL:=}",
|
||||
**kwargs,
|
||||
):
|
||||
return {
|
||||
|
|
|
@ -305,6 +305,8 @@ class _HfAdapter(
|
|||
|
||||
class TGIAdapter(_HfAdapter):
|
||||
async def initialize(self, config: TGIImplConfig) -> None:
|
||||
if not config.url:
|
||||
raise ValueError("You must provide a URL in run.yaml (or via the TGI_URL environment variable) to use TGI.")
|
||||
log.info(f"Initializing TGI client with url={config.url}")
|
||||
self.client = AsyncInferenceClient(
|
||||
model=config.url,
|
||||
|
|
|
@ -27,5 +27,5 @@ class TogetherImplConfig(RemoteInferenceProviderConfig):
|
|||
def sample_run_config(cls, **kwargs) -> dict[str, Any]:
|
||||
return {
|
||||
"url": "https://api.together.xyz/v1",
|
||||
"api_key": "${env.TOGETHER_API_KEY}",
|
||||
"api_key": "${env.TOGETHER_API_KEY:=}",
|
||||
}
|
||||
|
|
|
@ -69,15 +69,9 @@ MODEL_ENTRIES = [
|
|||
build_hf_repo_model_entry(
|
||||
"meta-llama/Llama-4-Scout-17B-16E-Instruct",
|
||||
CoreModelId.llama4_scout_17b_16e_instruct.value,
|
||||
additional_aliases=[
|
||||
"together/meta-llama/Llama-4-Scout-17B-16E-Instruct",
|
||||
],
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
|
||||
CoreModelId.llama4_maverick_17b_128e_instruct.value,
|
||||
additional_aliases=[
|
||||
"together/meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
|
||||
],
|
||||
),
|
||||
] + SAFETY_MODELS_ENTRIES
|
||||
|
|
|
@ -299,7 +299,10 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
self.client = None
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
if not self.config.url:
|
||||
raise ValueError(
|
||||
"You must provide a URL in run.yaml (or via the VLLM_URL environment variable) to use vLLM."
|
||||
)
|
||||
|
||||
async def should_refresh_models(self) -> bool:
|
||||
return self.config.refresh_models
|
||||
|
@ -337,9 +340,6 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
HealthResponse: A dictionary containing the health status.
|
||||
"""
|
||||
try:
|
||||
if not self.config.url:
|
||||
return HealthResponse(status=HealthStatus.ERROR, message="vLLM URL is not set")
|
||||
|
||||
client = self._create_client() if self.client is None else self.client
|
||||
_ = [m async for m in client.models.list()] # Ensure the client is initialized
|
||||
return HealthResponse(status=HealthStatus.OK)
|
||||
|
@ -355,11 +355,6 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
|||
if self.client is not None:
|
||||
return
|
||||
|
||||
if not self.config.url:
|
||||
raise ValueError(
|
||||
"You must provide a vLLM URL in the run.yaml file (or set the VLLM_URL environment variable)"
|
||||
)
|
||||
|
||||
log.info(f"Initializing vLLM client with base_url={self.config.url}")
|
||||
self.client = self._create_client()
|
||||
|
||||
|
|
|
@ -30,7 +30,7 @@ class SambaNovaSafetyConfig(BaseModel):
|
|||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, api_key: str = "${env.SAMBANOVA_API_KEY}", **kwargs) -> dict[str, Any]:
|
||||
def sample_run_config(cls, api_key: str = "${env.SAMBANOVA_API_KEY:=}", **kwargs) -> dict[str, Any]:
|
||||
return {
|
||||
"url": "https://api.sambanova.ai/v1",
|
||||
"api_key": api_key,
|
||||
|
|
|
@ -68,11 +68,14 @@ class LiteLLMOpenAIMixin(
|
|||
def __init__(
|
||||
self,
|
||||
model_entries,
|
||||
litellm_provider_name: str,
|
||||
api_key_from_config: str | None,
|
||||
provider_data_api_key_field: str,
|
||||
openai_compat_api_base: str | None = None,
|
||||
):
|
||||
ModelRegistryHelper.__init__(self, model_entries)
|
||||
|
||||
self.litellm_provider_name = litellm_provider_name
|
||||
self.api_key_from_config = api_key_from_config
|
||||
self.provider_data_api_key_field = provider_data_api_key_field
|
||||
self.api_base = openai_compat_api_base
|
||||
|
@ -91,7 +94,11 @@ class LiteLLMOpenAIMixin(
|
|||
def get_litellm_model_name(self, model_id: str) -> str:
|
||||
# users may be using openai/ prefix in their model names. the openai/models.py did this by default.
|
||||
# model_id.startswith("openai/") is for backwards compatibility.
|
||||
return "openai/" + model_id if self.is_openai_compat and not model_id.startswith("openai/") else model_id
|
||||
return (
|
||||
f"{self.litellm_provider_name}/{model_id}"
|
||||
if self.is_openai_compat and not model_id.startswith(self.litellm_provider_name)
|
||||
else model_id
|
||||
)
|
||||
|
||||
async def completion(
|
||||
self,
|
||||
|
|
|
@ -50,7 +50,8 @@ def build_hf_repo_model_entry(
|
|||
additional_aliases: list[str] | None = None,
|
||||
) -> ProviderModelEntry:
|
||||
aliases = [
|
||||
get_huggingface_repo(model_descriptor),
|
||||
# NOTE: avoid HF aliases because they _cannot_ be unique across providers
|
||||
# get_huggingface_repo(model_descriptor),
|
||||
]
|
||||
if additional_aliases:
|
||||
aliases.extend(additional_aliases)
|
||||
|
@ -75,7 +76,9 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
|
|||
__provider_id__: str
|
||||
|
||||
def __init__(self, model_entries: list[ProviderModelEntry], allowed_models: list[str] | None = None):
|
||||
self.model_entries = model_entries
|
||||
self.allowed_models = allowed_models
|
||||
|
||||
self.alias_to_provider_id_map = {}
|
||||
self.provider_id_to_llama_model_map = {}
|
||||
for entry in model_entries:
|
||||
|
@ -98,7 +101,7 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
|
|||
continue
|
||||
models.append(
|
||||
Model(
|
||||
model_id=id,
|
||||
identifier=id,
|
||||
provider_resource_id=entry.provider_model_id,
|
||||
model_type=ModelType.llm,
|
||||
metadata=entry.metadata,
|
||||
|
@ -185,8 +188,8 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
|
|||
return model
|
||||
|
||||
async def unregister_model(self, model_id: str) -> None:
|
||||
# TODO: should we block unregistering base supported provider model IDs?
|
||||
if model_id not in self.alias_to_provider_id_map:
|
||||
raise ValueError(f"Model id '{model_id}' is not registered.")
|
||||
|
||||
del self.alias_to_provider_id_map[model_id]
|
||||
# model_id is the identifier, not the provider_resource_id
|
||||
# unfortunately, this ID can be of the form provider_id/model_id which
|
||||
# we never registered. TODO: fix this by significantly rewriting
|
||||
# registration and registry helper
|
||||
pass
|
||||
|
|
|
@ -7,21 +7,15 @@ distribution_spec:
|
|||
- provider_type: remote::ollama
|
||||
- provider_type: remote::vllm
|
||||
- provider_type: remote::tgi
|
||||
- provider_type: remote::hf::serverless
|
||||
- provider_type: remote::hf::endpoint
|
||||
- provider_type: remote::fireworks
|
||||
- provider_type: remote::together
|
||||
- provider_type: remote::bedrock
|
||||
- provider_type: remote::databricks
|
||||
- provider_type: remote::nvidia
|
||||
- provider_type: remote::runpod
|
||||
- provider_type: remote::openai
|
||||
- provider_type: remote::anthropic
|
||||
- provider_type: remote::gemini
|
||||
- provider_type: remote::groq
|
||||
- provider_type: remote::llama-openai-compat
|
||||
- provider_type: remote::sambanova
|
||||
- provider_type: remote::passthrough
|
||||
- provider_type: inline::sentence-transformers
|
||||
vector_io:
|
||||
- provider_type: inline::faiss
|
||||
|
|
File diff suppressed because it is too large
Load diff
|
@ -89,101 +89,51 @@ models:
|
|||
provider_id: nvidia
|
||||
provider_model_id: meta/llama3-8b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/Llama-3-8B-Instruct
|
||||
provider_id: nvidia
|
||||
provider_model_id: meta/llama3-8b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta/llama3-70b-instruct
|
||||
provider_id: nvidia
|
||||
provider_model_id: meta/llama3-70b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/Llama-3-70B-Instruct
|
||||
provider_id: nvidia
|
||||
provider_model_id: meta/llama3-70b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta/llama-3.1-8b-instruct
|
||||
provider_id: nvidia
|
||||
provider_model_id: meta/llama-3.1-8b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/Llama-3.1-8B-Instruct
|
||||
provider_id: nvidia
|
||||
provider_model_id: meta/llama-3.1-8b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta/llama-3.1-70b-instruct
|
||||
provider_id: nvidia
|
||||
provider_model_id: meta/llama-3.1-70b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/Llama-3.1-70B-Instruct
|
||||
provider_id: nvidia
|
||||
provider_model_id: meta/llama-3.1-70b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta/llama-3.1-405b-instruct
|
||||
provider_id: nvidia
|
||||
provider_model_id: meta/llama-3.1-405b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/Llama-3.1-405B-Instruct-FP8
|
||||
provider_id: nvidia
|
||||
provider_model_id: meta/llama-3.1-405b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta/llama-3.2-1b-instruct
|
||||
provider_id: nvidia
|
||||
provider_model_id: meta/llama-3.2-1b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/Llama-3.2-1B-Instruct
|
||||
provider_id: nvidia
|
||||
provider_model_id: meta/llama-3.2-1b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta/llama-3.2-3b-instruct
|
||||
provider_id: nvidia
|
||||
provider_model_id: meta/llama-3.2-3b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/Llama-3.2-3B-Instruct
|
||||
provider_id: nvidia
|
||||
provider_model_id: meta/llama-3.2-3b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta/llama-3.2-11b-vision-instruct
|
||||
provider_id: nvidia
|
||||
provider_model_id: meta/llama-3.2-11b-vision-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/Llama-3.2-11B-Vision-Instruct
|
||||
provider_id: nvidia
|
||||
provider_model_id: meta/llama-3.2-11b-vision-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta/llama-3.2-90b-vision-instruct
|
||||
provider_id: nvidia
|
||||
provider_model_id: meta/llama-3.2-90b-vision-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/Llama-3.2-90B-Vision-Instruct
|
||||
provider_id: nvidia
|
||||
provider_model_id: meta/llama-3.2-90b-vision-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta/llama-3.3-70b-instruct
|
||||
provider_id: nvidia
|
||||
provider_model_id: meta/llama-3.3-70b-instruct
|
||||
model_type: llm
|
||||
- metadata: {}
|
||||
model_id: meta-llama/Llama-3.3-70B-Instruct
|
||||
provider_id: nvidia
|
||||
provider_model_id: meta/llama-3.3-70b-instruct
|
||||
model_type: llm
|
||||
- metadata:
|
||||
embedding_dimension: 2048
|
||||
context_length: 8192
|
||||
|
|
|
@ -33,7 +33,7 @@ providers:
|
|||
provider_type: remote::together
|
||||
config:
|
||||
url: https://api.together.xyz/v1
|
||||
api_key: ${env.TOGETHER_API_KEY}
|
||||
api_key: ${env.TOGETHER_API_KEY:=}
|
||||
vector_io:
|
||||
- provider_id: sqlite-vec
|
||||
provider_type: inline::sqlite-vec
|
||||
|
|
|
@ -7,21 +7,15 @@ distribution_spec:
|
|||
- provider_type: remote::ollama
|
||||
- provider_type: remote::vllm
|
||||
- provider_type: remote::tgi
|
||||
- provider_type: remote::hf::serverless
|
||||
- provider_type: remote::hf::endpoint
|
||||
- provider_type: remote::fireworks
|
||||
- provider_type: remote::together
|
||||
- provider_type: remote::bedrock
|
||||
- provider_type: remote::databricks
|
||||
- provider_type: remote::nvidia
|
||||
- provider_type: remote::runpod
|
||||
- provider_type: remote::openai
|
||||
- provider_type: remote::anthropic
|
||||
- provider_type: remote::gemini
|
||||
- provider_type: remote::groq
|
||||
- provider_type: remote::llama-openai-compat
|
||||
- provider_type: remote::sambanova
|
||||
- provider_type: remote::passthrough
|
||||
- provider_type: inline::sentence-transformers
|
||||
vector_io:
|
||||
- provider_type: inline::faiss
|
||||
|
|
File diff suppressed because it is too large
Load diff
|
@ -7,20 +7,19 @@
|
|||
|
||||
from typing import Any
|
||||
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.distribution.datatypes import (
|
||||
BuildProvider,
|
||||
ModelInput,
|
||||
Provider,
|
||||
ProviderSpec,
|
||||
ShieldInput,
|
||||
ToolGroupInput,
|
||||
)
|
||||
from llama_stack.distribution.utils.dynamic import instantiate_class_type
|
||||
from llama_stack.providers.datatypes import RemoteProviderSpec
|
||||
from llama_stack.providers.inline.files.localfs.config import LocalfsFilesImplConfig
|
||||
from llama_stack.providers.inline.inference.sentence_transformers import (
|
||||
SentenceTransformersInferenceConfig,
|
||||
)
|
||||
from llama_stack.providers.inline.post_training.huggingface import HuggingFacePostTrainingConfig
|
||||
from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
|
||||
from llama_stack.providers.inline.vector_io.milvus.config import (
|
||||
MilvusVectorIOConfig,
|
||||
|
@ -29,117 +28,17 @@ from llama_stack.providers.inline.vector_io.sqlite_vec.config import (
|
|||
SQLiteVectorIOConfig,
|
||||
)
|
||||
from llama_stack.providers.registry.inference import available_providers
|
||||
from llama_stack.providers.remote.inference.anthropic.models import (
|
||||
MODEL_ENTRIES as ANTHROPIC_MODEL_ENTRIES,
|
||||
)
|
||||
from llama_stack.providers.remote.inference.bedrock.models import (
|
||||
MODEL_ENTRIES as BEDROCK_MODEL_ENTRIES,
|
||||
)
|
||||
from llama_stack.providers.remote.inference.cerebras.models import (
|
||||
MODEL_ENTRIES as CEREBRAS_MODEL_ENTRIES,
|
||||
)
|
||||
from llama_stack.providers.remote.inference.databricks.databricks import (
|
||||
MODEL_ENTRIES as DATABRICKS_MODEL_ENTRIES,
|
||||
)
|
||||
from llama_stack.providers.remote.inference.fireworks.models import (
|
||||
MODEL_ENTRIES as FIREWORKS_MODEL_ENTRIES,
|
||||
)
|
||||
from llama_stack.providers.remote.inference.gemini.models import (
|
||||
MODEL_ENTRIES as GEMINI_MODEL_ENTRIES,
|
||||
)
|
||||
from llama_stack.providers.remote.inference.groq.models import (
|
||||
MODEL_ENTRIES as GROQ_MODEL_ENTRIES,
|
||||
)
|
||||
from llama_stack.providers.remote.inference.nvidia.models import (
|
||||
MODEL_ENTRIES as NVIDIA_MODEL_ENTRIES,
|
||||
)
|
||||
from llama_stack.providers.remote.inference.openai.models import (
|
||||
MODEL_ENTRIES as OPENAI_MODEL_ENTRIES,
|
||||
)
|
||||
from llama_stack.providers.remote.inference.runpod.runpod import (
|
||||
MODEL_ENTRIES as RUNPOD_MODEL_ENTRIES,
|
||||
)
|
||||
from llama_stack.providers.remote.inference.sambanova.models import (
|
||||
MODEL_ENTRIES as SAMBANOVA_MODEL_ENTRIES,
|
||||
)
|
||||
from llama_stack.providers.remote.inference.together.models import (
|
||||
MODEL_ENTRIES as TOGETHER_MODEL_ENTRIES,
|
||||
)
|
||||
from llama_stack.providers.remote.vector_io.chroma.config import ChromaVectorIOConfig
|
||||
from llama_stack.providers.remote.vector_io.pgvector.config import (
|
||||
PGVectorVectorIOConfig,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.model_registry import ProviderModelEntry
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import PostgresSqlStoreConfig
|
||||
from llama_stack.templates.template import (
|
||||
DistributionTemplate,
|
||||
RunConfigSettings,
|
||||
get_model_registry,
|
||||
get_shield_registry,
|
||||
)
|
||||
|
||||
|
||||
def _get_model_entries_for_provider(provider_type: str) -> list[ProviderModelEntry]:
|
||||
"""Get model entries for a specific provider type."""
|
||||
model_entries_map = {
|
||||
"openai": OPENAI_MODEL_ENTRIES,
|
||||
"fireworks": FIREWORKS_MODEL_ENTRIES,
|
||||
"together": TOGETHER_MODEL_ENTRIES,
|
||||
"anthropic": ANTHROPIC_MODEL_ENTRIES,
|
||||
"gemini": GEMINI_MODEL_ENTRIES,
|
||||
"groq": GROQ_MODEL_ENTRIES,
|
||||
"sambanova": SAMBANOVA_MODEL_ENTRIES,
|
||||
"cerebras": CEREBRAS_MODEL_ENTRIES,
|
||||
"bedrock": BEDROCK_MODEL_ENTRIES,
|
||||
"databricks": DATABRICKS_MODEL_ENTRIES,
|
||||
"nvidia": NVIDIA_MODEL_ENTRIES,
|
||||
"runpod": RUNPOD_MODEL_ENTRIES,
|
||||
}
|
||||
|
||||
# Special handling for providers with dynamic model entries
|
||||
if provider_type == "ollama":
|
||||
return [
|
||||
ProviderModelEntry(
|
||||
provider_model_id="${env.OLLAMA_INFERENCE_MODEL:=__disabled__}",
|
||||
model_type=ModelType.llm,
|
||||
),
|
||||
ProviderModelEntry(
|
||||
provider_model_id="${env.SAFETY_MODEL:=__disabled__}",
|
||||
model_type=ModelType.llm,
|
||||
),
|
||||
ProviderModelEntry(
|
||||
provider_model_id="${env.OLLAMA_EMBEDDING_MODEL:=__disabled__}",
|
||||
model_type=ModelType.embedding,
|
||||
metadata={
|
||||
"embedding_dimension": "${env.OLLAMA_EMBEDDING_DIMENSION:=384}",
|
||||
},
|
||||
),
|
||||
]
|
||||
elif provider_type == "vllm":
|
||||
return [
|
||||
ProviderModelEntry(
|
||||
provider_model_id="${env.VLLM_INFERENCE_MODEL:=__disabled__}",
|
||||
model_type=ModelType.llm,
|
||||
),
|
||||
]
|
||||
|
||||
return model_entries_map.get(provider_type, [])
|
||||
|
||||
|
||||
def _get_model_safety_entries_for_provider(provider_type: str) -> list[ProviderModelEntry]:
|
||||
"""Get model entries for a specific provider type."""
|
||||
safety_model_entries_map = {
|
||||
"ollama": [
|
||||
ProviderModelEntry(
|
||||
provider_model_id="${env.SAFETY_MODEL:=__disabled__}",
|
||||
model_type=ModelType.llm,
|
||||
),
|
||||
],
|
||||
}
|
||||
|
||||
return safety_model_entries_map.get(provider_type, [])
|
||||
|
||||
|
||||
def _get_config_for_provider(provider_spec: ProviderSpec) -> dict[str, Any]:
|
||||
"""Get configuration for a provider using its adapter's config class."""
|
||||
config_class = instantiate_class_type(provider_spec.config_class)
|
||||
|
@ -150,40 +49,48 @@ def _get_config_for_provider(provider_spec: ProviderSpec) -> dict[str, Any]:
|
|||
return {}
|
||||
|
||||
|
||||
def get_remote_inference_providers() -> tuple[list[Provider], dict[str, list[ProviderModelEntry]]]:
|
||||
all_providers = available_providers()
|
||||
ENABLED_INFERENCE_PROVIDERS = [
|
||||
"ollama",
|
||||
"vllm",
|
||||
"tgi",
|
||||
"fireworks",
|
||||
"together",
|
||||
"gemini",
|
||||
"groq",
|
||||
"sambanova",
|
||||
"anthropic",
|
||||
"openai",
|
||||
"cerebras",
|
||||
"nvidia",
|
||||
"bedrock",
|
||||
]
|
||||
|
||||
# Filter out inline providers and watsonx - the starter distro only exposes remote providers
|
||||
INFERENCE_PROVIDER_IDS = {
|
||||
"vllm": "${env.VLLM_URL:+vllm}",
|
||||
"tgi": "${env.TGI_URL:+tgi}",
|
||||
"cerebras": "${env.CEREBRAS_API_KEY:+cerebras}",
|
||||
"nvidia": "${env.NVIDIA_API_KEY:+nvidia}",
|
||||
}
|
||||
|
||||
|
||||
def get_remote_inference_providers() -> list[Provider]:
|
||||
# Filter out inline providers and some others - the starter distro only exposes remote providers
|
||||
remote_providers = [
|
||||
provider
|
||||
for provider in all_providers
|
||||
# TODO: re-add once the Python 3.13 issue is fixed
|
||||
# discussion: https://github.com/meta-llama/llama-stack/pull/2327#discussion_r2156883828
|
||||
if hasattr(provider, "adapter") and provider.adapter.adapter_type != "watsonx"
|
||||
for provider in available_providers()
|
||||
if isinstance(provider, RemoteProviderSpec) and provider.adapter.adapter_type in ENABLED_INFERENCE_PROVIDERS
|
||||
]
|
||||
|
||||
providers = []
|
||||
available_models = {}
|
||||
|
||||
inference_providers = []
|
||||
for provider_spec in remote_providers:
|
||||
provider_type = provider_spec.adapter.adapter_type
|
||||
|
||||
# Build the environment variable name for enabling this provider
|
||||
env_var = f"ENABLE_{provider_type.upper().replace('-', '_').replace('::', '_')}"
|
||||
model_entries = _get_model_entries_for_provider(provider_type)
|
||||
if provider_type in INFERENCE_PROVIDER_IDS:
|
||||
provider_id = INFERENCE_PROVIDER_IDS[provider_type]
|
||||
else:
|
||||
provider_id = provider_type.replace("-", "_").replace("::", "_")
|
||||
config = _get_config_for_provider(provider_spec)
|
||||
providers.append(
|
||||
(
|
||||
f"${{env.{env_var}:=__disabled__}}",
|
||||
provider_type,
|
||||
model_entries,
|
||||
config,
|
||||
)
|
||||
)
|
||||
available_models[f"${{env.{env_var}:=__disabled__}}"] = model_entries
|
||||
|
||||
inference_providers = []
|
||||
for provider_id, provider_type, model_entries, config in providers:
|
||||
inference_providers.append(
|
||||
Provider(
|
||||
provider_id=provider_id,
|
||||
|
@ -191,31 +98,13 @@ def get_remote_inference_providers() -> tuple[list[Provider], dict[str, list[Pro
|
|||
config=config,
|
||||
)
|
||||
)
|
||||
available_models[provider_id] = model_entries
|
||||
return inference_providers, available_models
|
||||
|
||||
|
||||
# build a list of shields for all possible providers
|
||||
def get_safety_models_for_providers(providers: list[Provider]) -> dict[str, list[ProviderModelEntry]]:
|
||||
available_models = {}
|
||||
for provider in providers:
|
||||
provider_type = provider.provider_type.split("::")[1]
|
||||
safety_model_entries = _get_model_safety_entries_for_provider(provider_type)
|
||||
if len(safety_model_entries) == 0:
|
||||
continue
|
||||
|
||||
env_var = f"ENABLE_{provider_type.upper().replace('-', '_').replace('::', '_')}"
|
||||
provider_id = f"${{env.{env_var}:=__disabled__}}"
|
||||
|
||||
available_models[provider_id] = safety_model_entries
|
||||
|
||||
return available_models
|
||||
return inference_providers
|
||||
|
||||
|
||||
def get_distribution_template() -> DistributionTemplate:
|
||||
remote_inference_providers, available_models = get_remote_inference_providers()
|
||||
remote_inference_providers = get_remote_inference_providers()
|
||||
name = "starter"
|
||||
# For build config, use BuildProvider with only provider_type and module
|
||||
|
||||
providers = {
|
||||
"inference": [BuildProvider(provider_type=p.provider_type, module=p.module) for p in remote_inference_providers]
|
||||
+ [BuildProvider(provider_type="inline::sentence-transformers")],
|
||||
|
@ -254,15 +143,10 @@ def get_distribution_template() -> DistributionTemplate:
|
|||
config=LocalfsFilesImplConfig.sample_run_config(f"~/.llama/distributions/{name}"),
|
||||
)
|
||||
embedding_provider = Provider(
|
||||
provider_id="${env.ENABLE_SENTENCE_TRANSFORMERS:=sentence-transformers}",
|
||||
provider_id="sentence-transformers",
|
||||
provider_type="inline::sentence-transformers",
|
||||
config=SentenceTransformersInferenceConfig.sample_run_config(),
|
||||
)
|
||||
post_training_provider = Provider(
|
||||
provider_id="huggingface",
|
||||
provider_type="inline::huggingface",
|
||||
config=HuggingFacePostTrainingConfig.sample_run_config(f"~/.llama/distributions/{name}"),
|
||||
)
|
||||
default_tool_groups = [
|
||||
ToolGroupInput(
|
||||
toolgroup_id="builtin::websearch",
|
||||
|
@ -273,19 +157,14 @@ def get_distribution_template() -> DistributionTemplate:
|
|||
provider_id="rag-runtime",
|
||||
),
|
||||
]
|
||||
embedding_model = ModelInput(
|
||||
model_id="all-MiniLM-L6-v2",
|
||||
provider_id=embedding_provider.provider_id,
|
||||
model_type=ModelType.embedding,
|
||||
metadata={
|
||||
"embedding_dimension": 384,
|
||||
},
|
||||
)
|
||||
|
||||
default_models, ids_conflict_in_models = get_model_registry(available_models)
|
||||
|
||||
available_safety_models = get_safety_models_for_providers(remote_inference_providers)
|
||||
shields = get_shield_registry(available_safety_models, ids_conflict_in_models)
|
||||
default_shields = [
|
||||
# if the
|
||||
ShieldInput(
|
||||
shield_id="llama-guard",
|
||||
provider_id="${env.SAFETY_MODEL:+llama-guard}",
|
||||
provider_shield_id="${env.SAFETY_MODEL:=}",
|
||||
),
|
||||
]
|
||||
|
||||
return DistributionTemplate(
|
||||
name=name,
|
||||
|
@ -294,7 +173,6 @@ def get_distribution_template() -> DistributionTemplate:
|
|||
container_image=None,
|
||||
template_path=None,
|
||||
providers=providers,
|
||||
available_models_by_provider=available_models,
|
||||
additional_pip_packages=PostgresSqlStoreConfig.pip_packages(),
|
||||
run_configs={
|
||||
"run.yaml": RunConfigSettings(
|
||||
|
@ -302,22 +180,22 @@ def get_distribution_template() -> DistributionTemplate:
|
|||
"inference": remote_inference_providers + [embedding_provider],
|
||||
"vector_io": [
|
||||
Provider(
|
||||
provider_id="${env.ENABLE_FAISS:=faiss}",
|
||||
provider_id="faiss",
|
||||
provider_type="inline::faiss",
|
||||
config=FaissVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
|
||||
),
|
||||
Provider(
|
||||
provider_id="${env.ENABLE_SQLITE_VEC:=__disabled__}",
|
||||
provider_id="sqlite-vec",
|
||||
provider_type="inline::sqlite-vec",
|
||||
config=SQLiteVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
|
||||
),
|
||||
Provider(
|
||||
provider_id="${env.ENABLE_MILVUS:=__disabled__}",
|
||||
provider_id="${env.MILVUS_URL:+milvus}",
|
||||
provider_type="inline::milvus",
|
||||
config=MilvusVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
|
||||
),
|
||||
Provider(
|
||||
provider_id="${env.ENABLE_CHROMADB:=__disabled__}",
|
||||
provider_id="${env.CHROMADB_URL:+chromadb}",
|
||||
provider_type="remote::chromadb",
|
||||
config=ChromaVectorIOConfig.sample_run_config(
|
||||
f"~/.llama/distributions/{name}/",
|
||||
|
@ -325,7 +203,7 @@ def get_distribution_template() -> DistributionTemplate:
|
|||
),
|
||||
),
|
||||
Provider(
|
||||
provider_id="${env.ENABLE_PGVECTOR:=__disabled__}",
|
||||
provider_id="${env.PGVECTOR_DB:+pgvector}",
|
||||
provider_type="remote::pgvector",
|
||||
config=PGVectorVectorIOConfig.sample_run_config(
|
||||
f"~/.llama/distributions/{name}",
|
||||
|
@ -336,12 +214,10 @@ def get_distribution_template() -> DistributionTemplate:
|
|||
),
|
||||
],
|
||||
"files": [files_provider],
|
||||
"post_training": [post_training_provider],
|
||||
},
|
||||
default_models=[embedding_model] + default_models,
|
||||
default_models=[],
|
||||
default_tool_groups=default_tool_groups,
|
||||
# TODO: add a way to enable/disable shields on the fly
|
||||
default_shields=shields,
|
||||
default_shields=default_shields,
|
||||
),
|
||||
},
|
||||
run_config_env_vars={
|
||||
|
@ -385,17 +261,5 @@ def get_distribution_template() -> DistributionTemplate:
|
|||
"http://localhost:11434",
|
||||
"Ollama URL",
|
||||
),
|
||||
"OLLAMA_INFERENCE_MODEL": (
|
||||
"",
|
||||
"Optional Ollama Inference Model to register on startup",
|
||||
),
|
||||
"OLLAMA_EMBEDDING_MODEL": (
|
||||
"",
|
||||
"Optional Ollama Embedding Model to register on startup",
|
||||
),
|
||||
"OLLAMA_EMBEDDING_DIMENSION": (
|
||||
"384",
|
||||
"Ollama Embedding Dimension",
|
||||
),
|
||||
},
|
||||
)
|
||||
|
|
|
@ -25,7 +25,7 @@ dependencies = [
|
|||
"fastapi>=0.115.0,<1.0", # server
|
||||
"fire", # for MCP in LLS client
|
||||
"httpx",
|
||||
"huggingface-hub>=0.30.0,<1.0",
|
||||
"huggingface-hub>=0.34.0,<1.0",
|
||||
"jinja2>=3.1.6",
|
||||
"jsonschema",
|
||||
"llama-stack-client>=0.2.15",
|
||||
|
|
|
@ -86,7 +86,7 @@ httpx==0.28.1
|
|||
# llama-stack
|
||||
# llama-stack-client
|
||||
# openai
|
||||
huggingface-hub==0.33.0
|
||||
huggingface-hub==0.34.1
|
||||
# via llama-stack
|
||||
idna==3.10
|
||||
# via
|
||||
|
|
|
@ -222,9 +222,7 @@ cmd=( run -d "${PLATFORM_OPTS[@]}" --name llama-stack \
|
|||
--network llama-net \
|
||||
-p "${PORT}:${PORT}" \
|
||||
"${SERVER_IMAGE}" --port "${PORT}" \
|
||||
--env OLLAMA_INFERENCE_MODEL="${MODEL_ALIAS}" \
|
||||
--env OLLAMA_URL="http://ollama-server:${OLLAMA_PORT}" \
|
||||
--env ENABLE_OLLAMA=ollama)
|
||||
--env OLLAMA_URL="http://ollama-server:${OLLAMA_PORT}")
|
||||
|
||||
log "🦙 Starting Llama Stack..."
|
||||
if ! execute_with_log $ENGINE "${cmd[@]}"; then
|
||||
|
|
|
@ -502,7 +502,7 @@ async def test_models_source_interaction_preserves_default(cached_disk_dist_regi
|
|||
|
||||
# Find the user model and provider model
|
||||
user_model = next((m for m in models.data if m.identifier == "my-custom-alias"), None)
|
||||
provider_model = next((m for m in models.data if m.identifier == "different-model"), None)
|
||||
provider_model = next((m for m in models.data if m.identifier == "test_provider/different-model"), None)
|
||||
|
||||
assert user_model is not None
|
||||
assert user_model.source == RegistryEntrySource.via_register_api
|
||||
|
@ -558,12 +558,12 @@ async def test_models_source_interaction_cleanup_provider_models(cached_disk_dis
|
|||
|
||||
identifiers = {m.identifier for m in models.data}
|
||||
assert "test_provider/user-model" in identifiers # User model preserved
|
||||
assert "provider-model-new" in identifiers # New provider model (uses provider's identifier)
|
||||
assert "provider-model-old" not in identifiers # Old provider model removed
|
||||
assert "test_provider/provider-model-new" in identifiers # New provider model (uses provider's identifier)
|
||||
assert "test_provider/provider-model-old" not in identifiers # Old provider model removed
|
||||
|
||||
# Verify sources are correct
|
||||
user_model = next((m for m in models.data if m.identifier == "test_provider/user-model"), None)
|
||||
provider_model = next((m for m in models.data if m.identifier == "provider-model-new"), None)
|
||||
provider_model = next((m for m in models.data if m.identifier == "test_provider/provider-model-new"), None)
|
||||
|
||||
assert user_model.source == RegistryEntrySource.via_register_api
|
||||
assert provider_model.source == RegistryEntrySource.listed_from_provider
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue