mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-27 06:28:50 +00:00
feat(starter)!: simplify starter distro; litellm model registry changes (#2916)
This commit is contained in:
parent
3344d8a9e5
commit
9583f468f8
64 changed files with 2027 additions and 4092 deletions
12
.github/workflows/integration-tests.yml
vendored
12
.github/workflows/integration-tests.yml
vendored
|
@ -117,17 +117,13 @@ jobs:
|
||||||
|
|
||||||
EXCLUDE_TESTS="builtin_tool or safety_with_image or code_interpreter or test_rag"
|
EXCLUDE_TESTS="builtin_tool or safety_with_image or code_interpreter or test_rag"
|
||||||
if [ "${{ matrix.provider }}" == "ollama" ]; then
|
if [ "${{ matrix.provider }}" == "ollama" ]; then
|
||||||
export ENABLE_OLLAMA="ollama"
|
|
||||||
export OLLAMA_URL="http://0.0.0.0:11434"
|
export OLLAMA_URL="http://0.0.0.0:11434"
|
||||||
export OLLAMA_INFERENCE_MODEL="llama3.2:3b-instruct-fp16"
|
export TEXT_MODEL=ollama/llama3.2:3b-instruct-fp16
|
||||||
export TEXT_MODEL=ollama/$OLLAMA_INFERENCE_MODEL
|
export SAFETY_MODEL="ollama/llama-guard3:1b"
|
||||||
export SAFETY_MODEL="llama-guard3:1b"
|
EXTRA_PARAMS="--safety-shield=llama-guard"
|
||||||
EXTRA_PARAMS="--safety-shield=$SAFETY_MODEL"
|
|
||||||
else
|
else
|
||||||
export ENABLE_VLLM="vllm"
|
|
||||||
export VLLM_URL="http://localhost:8000/v1"
|
export VLLM_URL="http://localhost:8000/v1"
|
||||||
export VLLM_INFERENCE_MODEL="meta-llama/Llama-3.2-1B-Instruct"
|
export TEXT_MODEL=vllm/meta-llama/Llama-3.2-1B-Instruct
|
||||||
export TEXT_MODEL=vllm/$VLLM_INFERENCE_MODEL
|
|
||||||
# TODO: remove the not(test_inference_store_tool_calls) once we can get the tool called consistently
|
# TODO: remove the not(test_inference_store_tool_calls) once we can get the tool called consistently
|
||||||
EXTRA_PARAMS=
|
EXTRA_PARAMS=
|
||||||
EXCLUDE_TESTS="${EXCLUDE_TESTS} or test_inference_store_tool_calls"
|
EXCLUDE_TESTS="${EXCLUDE_TESTS} or test_inference_store_tool_calls"
|
||||||
|
|
|
@ -249,12 +249,6 @@
|
||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"from llama_stack_client import Agent, AgentEventLogger, RAGDocument, LlamaStackClient\n",
|
"from llama_stack_client import Agent, AgentEventLogger, RAGDocument, LlamaStackClient\n",
|
||||||
"import os\n",
|
|
||||||
"\n",
|
|
||||||
"os.environ[\"ENABLE_OLLAMA\"] = \"ollama\"\n",
|
|
||||||
"os.environ[\"OLLAMA_INFERENCE_MODEL\"] = \"llama3.2:3b\"\n",
|
|
||||||
"os.environ[\"OLLAMA_EMBEDDING_MODEL\"] = \"all-minilm:l6-v2\"\n",
|
|
||||||
"os.environ[\"OLLAMA_EMBEDDING_DIMENSION\"] = \"384\"\n",
|
|
||||||
"\n",
|
"\n",
|
||||||
"vector_db_id = \"my_demo_vector_db\"\n",
|
"vector_db_id = \"my_demo_vector_db\"\n",
|
||||||
"client = LlamaStackClient(base_url=\"http://0.0.0.0:8321\")\n",
|
"client = LlamaStackClient(base_url=\"http://0.0.0.0:8321\")\n",
|
||||||
|
|
|
@ -40,16 +40,16 @@ The following environment variables can be configured:
|
||||||
|
|
||||||
The following models are available by default:
|
The following models are available by default:
|
||||||
|
|
||||||
- `meta/llama3-8b-instruct (aliases: meta-llama/Llama-3-8B-Instruct)`
|
- `meta/llama3-8b-instruct `
|
||||||
- `meta/llama3-70b-instruct (aliases: meta-llama/Llama-3-70B-Instruct)`
|
- `meta/llama3-70b-instruct `
|
||||||
- `meta/llama-3.1-8b-instruct (aliases: meta-llama/Llama-3.1-8B-Instruct)`
|
- `meta/llama-3.1-8b-instruct `
|
||||||
- `meta/llama-3.1-70b-instruct (aliases: meta-llama/Llama-3.1-70B-Instruct)`
|
- `meta/llama-3.1-70b-instruct `
|
||||||
- `meta/llama-3.1-405b-instruct (aliases: meta-llama/Llama-3.1-405B-Instruct-FP8)`
|
- `meta/llama-3.1-405b-instruct `
|
||||||
- `meta/llama-3.2-1b-instruct (aliases: meta-llama/Llama-3.2-1B-Instruct)`
|
- `meta/llama-3.2-1b-instruct `
|
||||||
- `meta/llama-3.2-3b-instruct (aliases: meta-llama/Llama-3.2-3B-Instruct)`
|
- `meta/llama-3.2-3b-instruct `
|
||||||
- `meta/llama-3.2-11b-vision-instruct (aliases: meta-llama/Llama-3.2-11B-Vision-Instruct)`
|
- `meta/llama-3.2-11b-vision-instruct `
|
||||||
- `meta/llama-3.2-90b-vision-instruct (aliases: meta-llama/Llama-3.2-90B-Vision-Instruct)`
|
- `meta/llama-3.2-90b-vision-instruct `
|
||||||
- `meta/llama-3.3-70b-instruct (aliases: meta-llama/Llama-3.3-70B-Instruct)`
|
- `meta/llama-3.3-70b-instruct `
|
||||||
- `nvidia/llama-3.2-nv-embedqa-1b-v2 `
|
- `nvidia/llama-3.2-nv-embedqa-1b-v2 `
|
||||||
- `nvidia/nv-embedqa-e5-v5 `
|
- `nvidia/nv-embedqa-e5-v5 `
|
||||||
- `nvidia/nv-embedqa-mistral-7b-v2 `
|
- `nvidia/nv-embedqa-mistral-7b-v2 `
|
||||||
|
|
|
@ -158,7 +158,7 @@ export ENABLE_PGVECTOR=__disabled__
|
||||||
The starter distribution uses several patterns for provider IDs:
|
The starter distribution uses several patterns for provider IDs:
|
||||||
|
|
||||||
1. **Direct provider IDs**: `faiss`, `ollama`, `vllm`
|
1. **Direct provider IDs**: `faiss`, `ollama`, `vllm`
|
||||||
2. **Environment-based provider IDs**: `${env.ENABLE_SQLITE_VEC+sqlite-vec}`
|
2. **Environment-based provider IDs**: `${env.ENABLE_SQLITE_VEC:+sqlite-vec}`
|
||||||
3. **Model-based provider IDs**: `${env.OLLAMA_INFERENCE_MODEL:__disabled__}`
|
3. **Model-based provider IDs**: `${env.OLLAMA_INFERENCE_MODEL:__disabled__}`
|
||||||
|
|
||||||
When using the `+` pattern (like `${env.ENABLE_SQLITE_VEC+sqlite-vec}`), the provider is enabled by default and can be disabled by setting the environment variable to `__disabled__`.
|
When using the `:+` pattern (like `${env.ENABLE_SQLITE_VEC:+sqlite-vec}`), the provider is enabled by default and can be disabled by setting the environment variable to `__disabled__`.
|
||||||
|
|
|
@ -59,7 +59,7 @@ Now let's build and run the Llama Stack config for Ollama.
|
||||||
We use `starter` as the template. By default, all providers are disabled, so Ollama must be enabled by passing environment variables.
|
We use `starter` as the template. By default, all providers are disabled, so Ollama must be enabled by passing environment variables.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
ENABLE_OLLAMA=ollama OLLAMA_INFERENCE_MODEL="llama3.2:3b" llama stack build --template starter --image-type venv --run
|
llama stack build --template starter --image-type venv --run
|
||||||
```
|
```
|
||||||
:::
|
:::
|
||||||
:::{tab-item} Using `conda`
|
:::{tab-item} Using `conda`
|
||||||
|
@ -70,7 +70,7 @@ which defines the providers and their settings.
|
||||||
Now let's build and run the Llama Stack config for Ollama.
|
Now let's build and run the Llama Stack config for Ollama.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
ENABLE_OLLAMA=ollama INFERENCE_MODEL="llama3.2:3b" llama stack build --template starter --image-type conda --run
|
llama stack build --template starter --image-type conda --run
|
||||||
```
|
```
|
||||||
:::
|
:::
|
||||||
:::{tab-item} Using a Container
|
:::{tab-item} Using a Container
|
||||||
|
@ -80,8 +80,6 @@ component that works with different inference providers out of the box. For this
|
||||||
configurations, please check out [this guide](../distributions/building_distro.md).
|
configurations, please check out [this guide](../distributions/building_distro.md).
|
||||||
First, let’s set up some environment variables and create a local directory to mount into the container’s file system.
|
First, let’s set up some environment variables and create a local directory to mount into the container’s file system.
|
||||||
```bash
|
```bash
|
||||||
export INFERENCE_MODEL="llama3.2:3b"
|
|
||||||
export ENABLE_OLLAMA=ollama
|
|
||||||
export LLAMA_STACK_PORT=8321
|
export LLAMA_STACK_PORT=8321
|
||||||
mkdir -p ~/.llama
|
mkdir -p ~/.llama
|
||||||
```
|
```
|
||||||
|
@ -94,7 +92,6 @@ docker run -it \
|
||||||
-v ~/.llama:/root/.llama \
|
-v ~/.llama:/root/.llama \
|
||||||
llamastack/distribution-starter \
|
llamastack/distribution-starter \
|
||||||
--port $LLAMA_STACK_PORT \
|
--port $LLAMA_STACK_PORT \
|
||||||
--env INFERENCE_MODEL=$INFERENCE_MODEL \
|
|
||||||
--env OLLAMA_URL=http://host.docker.internal:11434
|
--env OLLAMA_URL=http://host.docker.internal:11434
|
||||||
```
|
```
|
||||||
Note: to start the container with Podman, you can do the same but replace `docker` at the start of the command with
|
Note: to start the container with Podman, you can do the same but replace `docker` at the start of the command with
|
||||||
|
@ -116,7 +113,6 @@ docker run -it \
|
||||||
--network=host \
|
--network=host \
|
||||||
llamastack/distribution-starter \
|
llamastack/distribution-starter \
|
||||||
--port $LLAMA_STACK_PORT \
|
--port $LLAMA_STACK_PORT \
|
||||||
--env INFERENCE_MODEL=$INFERENCE_MODEL \
|
|
||||||
--env OLLAMA_URL=http://localhost:11434
|
--env OLLAMA_URL=http://localhost:11434
|
||||||
```
|
```
|
||||||
:::
|
:::
|
||||||
|
|
|
@ -19,7 +19,7 @@ ollama run llama3.2:3b --keepalive 60m
|
||||||
#### Step 2: Run the Llama Stack server
|
#### Step 2: Run the Llama Stack server
|
||||||
We will use `uv` to run the Llama Stack server.
|
We will use `uv` to run the Llama Stack server.
|
||||||
```bash
|
```bash
|
||||||
ENABLE_OLLAMA=ollama OLLAMA_INFERENCE_MODEL=llama3.2:3b uv run --with llama-stack llama stack build --template starter --image-type venv --run
|
uv run --with llama-stack llama stack build --template starter --image-type venv --run
|
||||||
```
|
```
|
||||||
#### Step 3: Run the demo
|
#### Step 3: Run the demo
|
||||||
Now open up a new terminal and copy the following script into a file named `demo_script.py`.
|
Now open up a new terminal and copy the following script into a file named `demo_script.py`.
|
||||||
|
|
|
@ -13,7 +13,7 @@ Anthropic inference provider for accessing Claude models and Anthropic's AI serv
|
||||||
## Sample Configuration
|
## Sample Configuration
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
api_key: ${env.ANTHROPIC_API_KEY}
|
api_key: ${env.ANTHROPIC_API_KEY:=}
|
||||||
|
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
|
@ -15,7 +15,7 @@ Cerebras inference provider for running models on Cerebras Cloud platform.
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
base_url: https://api.cerebras.ai
|
base_url: https://api.cerebras.ai
|
||||||
api_key: ${env.CEREBRAS_API_KEY}
|
api_key: ${env.CEREBRAS_API_KEY:=}
|
||||||
|
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
|
@ -14,8 +14,8 @@ Databricks inference provider for running models on Databricks' unified analytic
|
||||||
## Sample Configuration
|
## Sample Configuration
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
url: ${env.DATABRICKS_URL}
|
url: ${env.DATABRICKS_URL:=}
|
||||||
api_token: ${env.DATABRICKS_API_TOKEN}
|
api_token: ${env.DATABRICKS_API_TOKEN:=}
|
||||||
|
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
|
@ -16,7 +16,7 @@ Fireworks AI inference provider for Llama models and other AI models on the Fire
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
url: https://api.fireworks.ai/inference/v1
|
url: https://api.fireworks.ai/inference/v1
|
||||||
api_key: ${env.FIREWORKS_API_KEY}
|
api_key: ${env.FIREWORKS_API_KEY:=}
|
||||||
|
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
|
@ -13,7 +13,7 @@ Google Gemini inference provider for accessing Gemini models and Google's AI ser
|
||||||
## Sample Configuration
|
## Sample Configuration
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
api_key: ${env.GEMINI_API_KEY}
|
api_key: ${env.GEMINI_API_KEY:=}
|
||||||
|
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
|
@ -15,7 +15,7 @@ Groq inference provider for ultra-fast inference using Groq's LPU technology.
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
url: https://api.groq.com
|
url: https://api.groq.com
|
||||||
api_key: ${env.GROQ_API_KEY}
|
api_key: ${env.GROQ_API_KEY:=}
|
||||||
|
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
|
@ -13,7 +13,7 @@ OpenAI inference provider for accessing GPT models and other OpenAI services.
|
||||||
## Sample Configuration
|
## Sample Configuration
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
api_key: ${env.OPENAI_API_KEY}
|
api_key: ${env.OPENAI_API_KEY:=}
|
||||||
|
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
|
@ -15,7 +15,7 @@ SambaNova OpenAI-compatible provider for using SambaNova models with OpenAI API
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
openai_compat_api_base: https://api.sambanova.ai/v1
|
openai_compat_api_base: https://api.sambanova.ai/v1
|
||||||
api_key: ${env.SAMBANOVA_API_KEY}
|
api_key: ${env.SAMBANOVA_API_KEY:=}
|
||||||
|
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
|
@ -15,7 +15,7 @@ SambaNova inference provider for running models on SambaNova's dataflow architec
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
url: https://api.sambanova.ai/v1
|
url: https://api.sambanova.ai/v1
|
||||||
api_key: ${env.SAMBANOVA_API_KEY}
|
api_key: ${env.SAMBANOVA_API_KEY:=}
|
||||||
|
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
|
@ -13,7 +13,7 @@ Text Generation Inference (TGI) provider for HuggingFace model serving.
|
||||||
## Sample Configuration
|
## Sample Configuration
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
url: ${env.TGI_URL}
|
url: ${env.TGI_URL:=}
|
||||||
|
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
|
@ -16,7 +16,7 @@ Together AI inference provider for open-source models and collaborative AI devel
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
url: https://api.together.xyz/v1
|
url: https://api.together.xyz/v1
|
||||||
api_key: ${env.TOGETHER_API_KEY}
|
api_key: ${env.TOGETHER_API_KEY:=}
|
||||||
|
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
|
@ -15,7 +15,7 @@ SambaNova's safety provider for content moderation and safety filtering.
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
url: https://api.sambanova.ai/v1
|
url: https://api.sambanova.ai/v1
|
||||||
api_key: ${env.SAMBANOVA_API_KEY}
|
api_key: ${env.SAMBANOVA_API_KEY:=}
|
||||||
|
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
|
@ -25,7 +25,8 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
||||||
async def refresh(self) -> None:
|
async def refresh(self) -> None:
|
||||||
for provider_id, provider in self.impls_by_provider_id.items():
|
for provider_id, provider in self.impls_by_provider_id.items():
|
||||||
refresh = await provider.should_refresh_models()
|
refresh = await provider.should_refresh_models()
|
||||||
if not (refresh or provider_id in self.listed_providers):
|
refresh = refresh or provider_id not in self.listed_providers
|
||||||
|
if not refresh:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -138,6 +139,9 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
||||||
# avoid overwriting a non-provider-registered model entry
|
# avoid overwriting a non-provider-registered model entry
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
if model.identifier == model.provider_resource_id:
|
||||||
|
model.identifier = f"{provider_id}/{model.provider_resource_id}"
|
||||||
|
|
||||||
logger.debug(f"registering model {model.identifier} ({model.provider_resource_id})")
|
logger.debug(f"registering model {model.identifier} ({model.provider_resource_id})")
|
||||||
await self.register_object(
|
await self.register_object(
|
||||||
ModelWithOwner(
|
ModelWithOwner(
|
||||||
|
|
|
@ -611,11 +611,8 @@ def extract_path_params(route: str) -> list[str]:
|
||||||
|
|
||||||
def remove_disabled_providers(obj):
|
def remove_disabled_providers(obj):
|
||||||
if isinstance(obj, dict):
|
if isinstance(obj, dict):
|
||||||
if (
|
keys = ["provider_id", "shield_id", "provider_model_id", "model_id"]
|
||||||
obj.get("provider_id") == "__disabled__"
|
if any(k in obj and obj[k] in ("__disabled__", "", None) for k in keys):
|
||||||
or obj.get("shield_id") == "__disabled__"
|
|
||||||
or obj.get("provider_model_id") == "__disabled__"
|
|
||||||
):
|
|
||||||
return None
|
return None
|
||||||
return {k: v for k, v in ((k, remove_disabled_providers(v)) for k, v in obj.items()) if v is not None}
|
return {k: v for k, v in ((k, remove_disabled_providers(v)) for k, v in obj.items()) if v is not None}
|
||||||
elif isinstance(obj, list):
|
elif isinstance(obj, list):
|
||||||
|
|
|
@ -105,23 +105,10 @@ async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]):
|
||||||
method = getattr(impls[api], register_method)
|
method = getattr(impls[api], register_method)
|
||||||
for obj in objects:
|
for obj in objects:
|
||||||
logger.debug(f"registering {rsrc.capitalize()} {obj} for provider {obj.provider_id}")
|
logger.debug(f"registering {rsrc.capitalize()} {obj} for provider {obj.provider_id}")
|
||||||
# Do not register models on disabled providers
|
|
||||||
if hasattr(obj, "provider_id") and obj.provider_id is not None and obj.provider_id == "__disabled__":
|
|
||||||
logger.debug(f"Skipping {rsrc.capitalize()} registration for disabled provider.")
|
|
||||||
continue
|
|
||||||
# In complex templates, like our starter template, we may have dynamic model ids
|
|
||||||
# given by environment variables. This allows those environment variables to have
|
|
||||||
# a default value of __disabled__ to skip registration of the model if not set.
|
|
||||||
if (
|
|
||||||
hasattr(obj, "provider_model_id")
|
|
||||||
and obj.provider_model_id is not None
|
|
||||||
and "__disabled__" in obj.provider_model_id
|
|
||||||
):
|
|
||||||
logger.debug(f"Skipping {rsrc.capitalize()} registration for disabled model.")
|
|
||||||
continue
|
|
||||||
|
|
||||||
if hasattr(obj, "shield_id") and obj.shield_id is not None and obj.shield_id == "__disabled__":
|
# Do not register models on disabled providers
|
||||||
logger.debug(f"Skipping {rsrc.capitalize()} registration for disabled shield.")
|
if hasattr(obj, "provider_id") and (not obj.provider_id or obj.provider_id == "__disabled__"):
|
||||||
|
logger.debug(f"Skipping {rsrc.capitalize()} registration for disabled provider.")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# we want to maintain the type information in arguments to method.
|
# we want to maintain the type information in arguments to method.
|
||||||
|
@ -331,8 +318,10 @@ async def construct_stack(
|
||||||
|
|
||||||
await register_resources(run_config, impls)
|
await register_resources(run_config, impls)
|
||||||
|
|
||||||
|
await refresh_registry_once(impls)
|
||||||
|
|
||||||
global REGISTRY_REFRESH_TASK
|
global REGISTRY_REFRESH_TASK
|
||||||
REGISTRY_REFRESH_TASK = asyncio.create_task(refresh_registry(impls))
|
REGISTRY_REFRESH_TASK = asyncio.create_task(refresh_registry_task(impls))
|
||||||
|
|
||||||
def cb(task):
|
def cb(task):
|
||||||
import traceback
|
import traceback
|
||||||
|
@ -368,12 +357,18 @@ async def shutdown_stack(impls: dict[Api, Any]):
|
||||||
REGISTRY_REFRESH_TASK.cancel()
|
REGISTRY_REFRESH_TASK.cancel()
|
||||||
|
|
||||||
|
|
||||||
async def refresh_registry(impls: dict[Api, Any]):
|
async def refresh_registry_once(impls: dict[Api, Any]):
|
||||||
|
logger.info("refreshing registry")
|
||||||
routing_tables = [v for v in impls.values() if isinstance(v, CommonRoutingTableImpl)]
|
routing_tables = [v for v in impls.values() if isinstance(v, CommonRoutingTableImpl)]
|
||||||
while True:
|
|
||||||
for routing_table in routing_tables:
|
for routing_table in routing_tables:
|
||||||
await routing_table.refresh()
|
await routing_table.refresh()
|
||||||
|
|
||||||
|
|
||||||
|
async def refresh_registry_task(impls: dict[Api, Any]):
|
||||||
|
logger.info("starting registry refresh task")
|
||||||
|
while True:
|
||||||
|
await refresh_registry_once(impls)
|
||||||
|
|
||||||
await asyncio.sleep(REGISTRY_REFRESH_INTERVAL_SECONDS)
|
await asyncio.sleep(REGISTRY_REFRESH_INTERVAL_SECONDS)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -43,6 +43,9 @@ class ModelsProtocolPrivate(Protocol):
|
||||||
-> Provider uses provider-model-id for inference
|
-> Provider uses provider-model-id for inference
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# this should be called `on_model_register` or something like that.
|
||||||
|
# the provider should _not_ be able to change the object in this
|
||||||
|
# callback
|
||||||
async def register_model(self, model: Model) -> Model: ...
|
async def register_model(self, model: Model) -> Model: ...
|
||||||
|
|
||||||
async def unregister_model(self, model_id: str) -> None: ...
|
async def unregister_model(self, model_id: str) -> None: ...
|
||||||
|
|
|
@ -146,9 +146,9 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def register_shield(self, shield: Shield) -> None:
|
async def register_shield(self, shield: Shield) -> None:
|
||||||
# Allow any model to be registered as a shield
|
model_id = shield.provider_resource_id
|
||||||
# The model will be validated during runtime when making inference calls
|
if not model_id:
|
||||||
pass
|
raise ValueError("Llama Guard shield must have a model id")
|
||||||
|
|
||||||
async def run_shield(
|
async def run_shield(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -15,6 +15,7 @@ class AnthropicInferenceAdapter(LiteLLMOpenAIMixin):
|
||||||
LiteLLMOpenAIMixin.__init__(
|
LiteLLMOpenAIMixin.__init__(
|
||||||
self,
|
self,
|
||||||
MODEL_ENTRIES,
|
MODEL_ENTRIES,
|
||||||
|
litellm_provider_name="anthropic",
|
||||||
api_key_from_config=config.api_key,
|
api_key_from_config=config.api_key,
|
||||||
provider_data_api_key_field="anthropic_api_key",
|
provider_data_api_key_field="anthropic_api_key",
|
||||||
)
|
)
|
||||||
|
|
|
@ -26,7 +26,7 @@ class AnthropicConfig(BaseModel):
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def sample_run_config(cls, api_key: str = "${env.ANTHROPIC_API_KEY}", **kwargs) -> dict[str, Any]:
|
def sample_run_config(cls, api_key: str = "${env.ANTHROPIC_API_KEY:=}", **kwargs) -> dict[str, Any]:
|
||||||
return {
|
return {
|
||||||
"api_key": api_key,
|
"api_key": api_key,
|
||||||
}
|
}
|
||||||
|
|
|
@ -10,9 +10,9 @@ from llama_stack.providers.utils.inference.model_registry import (
|
||||||
)
|
)
|
||||||
|
|
||||||
LLM_MODEL_IDS = [
|
LLM_MODEL_IDS = [
|
||||||
"anthropic/claude-3-5-sonnet-latest",
|
"claude-3-5-sonnet-latest",
|
||||||
"anthropic/claude-3-7-sonnet-latest",
|
"claude-3-7-sonnet-latest",
|
||||||
"anthropic/claude-3-5-haiku-latest",
|
"claude-3-5-haiku-latest",
|
||||||
]
|
]
|
||||||
|
|
||||||
SAFETY_MODELS_ENTRIES = []
|
SAFETY_MODELS_ENTRIES = []
|
||||||
|
@ -21,17 +21,17 @@ MODEL_ENTRIES = (
|
||||||
[ProviderModelEntry(provider_model_id=m) for m in LLM_MODEL_IDS]
|
[ProviderModelEntry(provider_model_id=m) for m in LLM_MODEL_IDS]
|
||||||
+ [
|
+ [
|
||||||
ProviderModelEntry(
|
ProviderModelEntry(
|
||||||
provider_model_id="anthropic/voyage-3",
|
provider_model_id="voyage-3",
|
||||||
model_type=ModelType.embedding,
|
model_type=ModelType.embedding,
|
||||||
metadata={"embedding_dimension": 1024, "context_length": 32000},
|
metadata={"embedding_dimension": 1024, "context_length": 32000},
|
||||||
),
|
),
|
||||||
ProviderModelEntry(
|
ProviderModelEntry(
|
||||||
provider_model_id="anthropic/voyage-3-lite",
|
provider_model_id="voyage-3-lite",
|
||||||
model_type=ModelType.embedding,
|
model_type=ModelType.embedding,
|
||||||
metadata={"embedding_dimension": 512, "context_length": 32000},
|
metadata={"embedding_dimension": 512, "context_length": 32000},
|
||||||
),
|
),
|
||||||
ProviderModelEntry(
|
ProviderModelEntry(
|
||||||
provider_model_id="anthropic/voyage-code-3",
|
provider_model_id="voyage-code-3",
|
||||||
model_type=ModelType.embedding,
|
model_type=ModelType.embedding,
|
||||||
metadata={"embedding_dimension": 1024, "context_length": 32000},
|
metadata={"embedding_dimension": 1024, "context_length": 32000},
|
||||||
),
|
),
|
||||||
|
|
|
@ -63,18 +63,20 @@ class BedrockInferenceAdapter(
|
||||||
def __init__(self, config: BedrockConfig) -> None:
|
def __init__(self, config: BedrockConfig) -> None:
|
||||||
ModelRegistryHelper.__init__(self, MODEL_ENTRIES)
|
ModelRegistryHelper.__init__(self, MODEL_ENTRIES)
|
||||||
self._config = config
|
self._config = config
|
||||||
|
self._client = None
|
||||||
self._client = create_bedrock_client(config)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def client(self) -> BaseClient:
|
def client(self) -> BaseClient:
|
||||||
|
if self._client is None:
|
||||||
|
self._client = create_bedrock_client(self._config)
|
||||||
return self._client
|
return self._client
|
||||||
|
|
||||||
async def initialize(self) -> None:
|
async def initialize(self) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def shutdown(self) -> None:
|
async def shutdown(self) -> None:
|
||||||
self.client.close()
|
if self._client is not None:
|
||||||
|
self._client.close()
|
||||||
|
|
||||||
async def completion(
|
async def completion(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -65,6 +65,7 @@ class CerebrasInferenceAdapter(
|
||||||
)
|
)
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
|
# TODO: make this use provider data, etc. like other providers
|
||||||
self.client = AsyncCerebras(
|
self.client = AsyncCerebras(
|
||||||
base_url=self.config.base_url,
|
base_url=self.config.base_url,
|
||||||
api_key=self.config.api_key.get_secret_value(),
|
api_key=self.config.api_key.get_secret_value(),
|
||||||
|
|
|
@ -26,7 +26,7 @@ class CerebrasImplConfig(BaseModel):
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def sample_run_config(cls, api_key: str = "${env.CEREBRAS_API_KEY}", **kwargs) -> dict[str, Any]:
|
def sample_run_config(cls, api_key: str = "${env.CEREBRAS_API_KEY:=}", **kwargs) -> dict[str, Any]:
|
||||||
return {
|
return {
|
||||||
"base_url": DEFAULT_BASE_URL,
|
"base_url": DEFAULT_BASE_URL,
|
||||||
"api_key": api_key,
|
"api_key": api_key,
|
||||||
|
|
|
@ -25,8 +25,8 @@ class DatabricksImplConfig(BaseModel):
|
||||||
@classmethod
|
@classmethod
|
||||||
def sample_run_config(
|
def sample_run_config(
|
||||||
cls,
|
cls,
|
||||||
url: str = "${env.DATABRICKS_URL}",
|
url: str = "${env.DATABRICKS_URL:=}",
|
||||||
api_token: str = "${env.DATABRICKS_API_TOKEN}",
|
api_token: str = "${env.DATABRICKS_API_TOKEN:=}",
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
return {
|
return {
|
||||||
|
|
|
@ -24,7 +24,7 @@ class FireworksImplConfig(RemoteInferenceProviderConfig):
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def sample_run_config(cls, api_key: str = "${env.FIREWORKS_API_KEY}", **kwargs) -> dict[str, Any]:
|
def sample_run_config(cls, api_key: str = "${env.FIREWORKS_API_KEY:=}", **kwargs) -> dict[str, Any]:
|
||||||
return {
|
return {
|
||||||
"url": "https://api.fireworks.ai/inference/v1",
|
"url": "https://api.fireworks.ai/inference/v1",
|
||||||
"api_key": api_key,
|
"api_key": api_key,
|
||||||
|
|
|
@ -26,7 +26,7 @@ class GeminiConfig(BaseModel):
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def sample_run_config(cls, api_key: str = "${env.GEMINI_API_KEY}", **kwargs) -> dict[str, Any]:
|
def sample_run_config(cls, api_key: str = "${env.GEMINI_API_KEY:=}", **kwargs) -> dict[str, Any]:
|
||||||
return {
|
return {
|
||||||
"api_key": api_key,
|
"api_key": api_key,
|
||||||
}
|
}
|
||||||
|
|
|
@ -15,6 +15,7 @@ class GeminiInferenceAdapter(LiteLLMOpenAIMixin):
|
||||||
LiteLLMOpenAIMixin.__init__(
|
LiteLLMOpenAIMixin.__init__(
|
||||||
self,
|
self,
|
||||||
MODEL_ENTRIES,
|
MODEL_ENTRIES,
|
||||||
|
litellm_provider_name="gemini",
|
||||||
api_key_from_config=config.api_key,
|
api_key_from_config=config.api_key,
|
||||||
provider_data_api_key_field="gemini_api_key",
|
provider_data_api_key_field="gemini_api_key",
|
||||||
)
|
)
|
||||||
|
|
|
@ -10,11 +10,11 @@ from llama_stack.providers.utils.inference.model_registry import (
|
||||||
)
|
)
|
||||||
|
|
||||||
LLM_MODEL_IDS = [
|
LLM_MODEL_IDS = [
|
||||||
"gemini/gemini-1.5-flash",
|
"gemini-1.5-flash",
|
||||||
"gemini/gemini-1.5-pro",
|
"gemini-1.5-pro",
|
||||||
"gemini/gemini-2.0-flash",
|
"gemini-2.0-flash",
|
||||||
"gemini/gemini-2.5-flash",
|
"gemini-2.5-flash",
|
||||||
"gemini/gemini-2.5-pro",
|
"gemini-2.5-pro",
|
||||||
]
|
]
|
||||||
|
|
||||||
SAFETY_MODELS_ENTRIES = []
|
SAFETY_MODELS_ENTRIES = []
|
||||||
|
@ -23,7 +23,7 @@ MODEL_ENTRIES = (
|
||||||
[ProviderModelEntry(provider_model_id=m) for m in LLM_MODEL_IDS]
|
[ProviderModelEntry(provider_model_id=m) for m in LLM_MODEL_IDS]
|
||||||
+ [
|
+ [
|
||||||
ProviderModelEntry(
|
ProviderModelEntry(
|
||||||
provider_model_id="gemini/text-embedding-004",
|
provider_model_id="text-embedding-004",
|
||||||
model_type=ModelType.embedding,
|
model_type=ModelType.embedding,
|
||||||
metadata={"embedding_dimension": 768, "context_length": 2048},
|
metadata={"embedding_dimension": 768, "context_length": 2048},
|
||||||
),
|
),
|
||||||
|
|
|
@ -32,7 +32,7 @@ class GroqConfig(BaseModel):
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def sample_run_config(cls, api_key: str = "${env.GROQ_API_KEY}", **kwargs) -> dict[str, Any]:
|
def sample_run_config(cls, api_key: str = "${env.GROQ_API_KEY:=}", **kwargs) -> dict[str, Any]:
|
||||||
return {
|
return {
|
||||||
"url": "https://api.groq.com",
|
"url": "https://api.groq.com",
|
||||||
"api_key": api_key,
|
"api_key": api_key,
|
||||||
|
|
|
@ -34,6 +34,7 @@ class GroqInferenceAdapter(LiteLLMOpenAIMixin):
|
||||||
LiteLLMOpenAIMixin.__init__(
|
LiteLLMOpenAIMixin.__init__(
|
||||||
self,
|
self,
|
||||||
model_entries=MODEL_ENTRIES,
|
model_entries=MODEL_ENTRIES,
|
||||||
|
litellm_provider_name="groq",
|
||||||
api_key_from_config=config.api_key,
|
api_key_from_config=config.api_key,
|
||||||
provider_data_api_key_field="groq_api_key",
|
provider_data_api_key_field="groq_api_key",
|
||||||
)
|
)
|
||||||
|
@ -96,7 +97,7 @@ class GroqInferenceAdapter(LiteLLMOpenAIMixin):
|
||||||
tool_choice = "required"
|
tool_choice = "required"
|
||||||
|
|
||||||
params = await prepare_openai_completion_params(
|
params = await prepare_openai_completion_params(
|
||||||
model=model_obj.provider_resource_id.replace("groq/", ""),
|
model=model_obj.provider_resource_id,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
frequency_penalty=frequency_penalty,
|
frequency_penalty=frequency_penalty,
|
||||||
function_call=function_call,
|
function_call=function_call,
|
||||||
|
|
|
@ -14,19 +14,19 @@ SAFETY_MODELS_ENTRIES = []
|
||||||
|
|
||||||
MODEL_ENTRIES = [
|
MODEL_ENTRIES = [
|
||||||
build_hf_repo_model_entry(
|
build_hf_repo_model_entry(
|
||||||
"groq/llama3-8b-8192",
|
"llama3-8b-8192",
|
||||||
CoreModelId.llama3_1_8b_instruct.value,
|
CoreModelId.llama3_1_8b_instruct.value,
|
||||||
),
|
),
|
||||||
build_model_entry(
|
build_model_entry(
|
||||||
"groq/llama-3.1-8b-instant",
|
"llama-3.1-8b-instant",
|
||||||
CoreModelId.llama3_1_8b_instruct.value,
|
CoreModelId.llama3_1_8b_instruct.value,
|
||||||
),
|
),
|
||||||
build_hf_repo_model_entry(
|
build_hf_repo_model_entry(
|
||||||
"groq/llama3-70b-8192",
|
"llama3-70b-8192",
|
||||||
CoreModelId.llama3_70b_instruct.value,
|
CoreModelId.llama3_70b_instruct.value,
|
||||||
),
|
),
|
||||||
build_hf_repo_model_entry(
|
build_hf_repo_model_entry(
|
||||||
"groq/llama-3.3-70b-versatile",
|
"llama-3.3-70b-versatile",
|
||||||
CoreModelId.llama3_3_70b_instruct.value,
|
CoreModelId.llama3_3_70b_instruct.value,
|
||||||
),
|
),
|
||||||
# Groq only contains a preview version for llama-3.2-3b
|
# Groq only contains a preview version for llama-3.2-3b
|
||||||
|
@ -34,23 +34,15 @@ MODEL_ENTRIES = [
|
||||||
# to pass the test fixture
|
# to pass the test fixture
|
||||||
# TODO(aidand): Replace this with a stable model once Groq supports it
|
# TODO(aidand): Replace this with a stable model once Groq supports it
|
||||||
build_hf_repo_model_entry(
|
build_hf_repo_model_entry(
|
||||||
"groq/llama-3.2-3b-preview",
|
"llama-3.2-3b-preview",
|
||||||
CoreModelId.llama3_2_3b_instruct.value,
|
CoreModelId.llama3_2_3b_instruct.value,
|
||||||
),
|
),
|
||||||
build_hf_repo_model_entry(
|
build_hf_repo_model_entry(
|
||||||
"groq/llama-4-scout-17b-16e-instruct",
|
"meta-llama/llama-4-scout-17b-16e-instruct",
|
||||||
CoreModelId.llama4_scout_17b_16e_instruct.value,
|
CoreModelId.llama4_scout_17b_16e_instruct.value,
|
||||||
),
|
),
|
||||||
build_hf_repo_model_entry(
|
build_hf_repo_model_entry(
|
||||||
"groq/meta-llama/llama-4-scout-17b-16e-instruct",
|
"meta-llama/llama-4-maverick-17b-128e-instruct",
|
||||||
CoreModelId.llama4_scout_17b_16e_instruct.value,
|
|
||||||
),
|
|
||||||
build_hf_repo_model_entry(
|
|
||||||
"groq/llama-4-maverick-17b-128e-instruct",
|
|
||||||
CoreModelId.llama4_maverick_17b_128e_instruct.value,
|
|
||||||
),
|
|
||||||
build_hf_repo_model_entry(
|
|
||||||
"groq/meta-llama/llama-4-maverick-17b-128e-instruct",
|
|
||||||
CoreModelId.llama4_maverick_17b_128e_instruct.value,
|
CoreModelId.llama4_maverick_17b_128e_instruct.value,
|
||||||
),
|
),
|
||||||
] + SAFETY_MODELS_ENTRIES
|
] + SAFETY_MODELS_ENTRIES
|
||||||
|
|
|
@ -32,6 +32,7 @@ class LlamaCompatInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
|
||||||
LiteLLMOpenAIMixin.__init__(
|
LiteLLMOpenAIMixin.__init__(
|
||||||
self,
|
self,
|
||||||
model_entries=MODEL_ENTRIES,
|
model_entries=MODEL_ENTRIES,
|
||||||
|
litellm_provider_name="llama",
|
||||||
api_key_from_config=config.api_key,
|
api_key_from_config=config.api_key,
|
||||||
provider_data_api_key_field="llama_api_key",
|
provider_data_api_key_field="llama_api_key",
|
||||||
openai_compat_api_base=config.openai_compat_api_base,
|
openai_compat_api_base=config.openai_compat_api_base,
|
||||||
|
|
|
@ -166,7 +166,7 @@ class OllamaInferenceAdapter(
|
||||||
]
|
]
|
||||||
for m in response.models:
|
for m in response.models:
|
||||||
# kill embedding models since we don't know dimensions for them
|
# kill embedding models since we don't know dimensions for them
|
||||||
if m.details.family in ["bert"]:
|
if "bert" in m.details.family:
|
||||||
continue
|
continue
|
||||||
models.append(
|
models.append(
|
||||||
Model(
|
Model(
|
||||||
|
@ -420,9 +420,6 @@ class OllamaInferenceAdapter(
|
||||||
except ValueError:
|
except ValueError:
|
||||||
pass # Ignore statically unknown model, will check live listing
|
pass # Ignore statically unknown model, will check live listing
|
||||||
|
|
||||||
if model.provider_resource_id is None:
|
|
||||||
raise ValueError("Model provider_resource_id cannot be None")
|
|
||||||
|
|
||||||
if model.model_type == ModelType.embedding:
|
if model.model_type == ModelType.embedding:
|
||||||
response = await self.client.list()
|
response = await self.client.list()
|
||||||
if model.provider_resource_id not in [m.model for m in response.models]:
|
if model.provider_resource_id not in [m.model for m in response.models]:
|
||||||
|
@ -433,9 +430,9 @@ class OllamaInferenceAdapter(
|
||||||
# - models not currently running are run by the ollama server as needed
|
# - models not currently running are run by the ollama server as needed
|
||||||
response = await self.client.list()
|
response = await self.client.list()
|
||||||
available_models = [m.model for m in response.models]
|
available_models = [m.model for m in response.models]
|
||||||
provider_resource_id = self.register_helper.get_provider_model_id(model.provider_resource_id)
|
|
||||||
if provider_resource_id is None:
|
|
||||||
provider_resource_id = model.provider_resource_id
|
provider_resource_id = model.provider_resource_id
|
||||||
|
assert provider_resource_id is not None # mypy
|
||||||
if provider_resource_id not in available_models:
|
if provider_resource_id not in available_models:
|
||||||
available_models_latest = [m.model.split(":latest")[0] for m in response.models]
|
available_models_latest = [m.model.split(":latest")[0] for m in response.models]
|
||||||
if provider_resource_id in available_models_latest:
|
if provider_resource_id in available_models_latest:
|
||||||
|
@ -443,7 +440,9 @@ class OllamaInferenceAdapter(
|
||||||
f"Imprecise provider resource id was used but 'latest' is available in Ollama - using '{model.provider_resource_id}:latest'"
|
f"Imprecise provider resource id was used but 'latest' is available in Ollama - using '{model.provider_resource_id}:latest'"
|
||||||
)
|
)
|
||||||
return model
|
return model
|
||||||
raise UnsupportedModelError(model.provider_resource_id, available_models)
|
raise UnsupportedModelError(provider_resource_id, available_models)
|
||||||
|
|
||||||
|
# mutating this should be considered an anti-pattern
|
||||||
model.provider_resource_id = provider_resource_id
|
model.provider_resource_id = provider_resource_id
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
|
@ -26,7 +26,7 @@ class OpenAIConfig(BaseModel):
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def sample_run_config(cls, api_key: str = "${env.OPENAI_API_KEY}", **kwargs) -> dict[str, Any]:
|
def sample_run_config(cls, api_key: str = "${env.OPENAI_API_KEY:=}", **kwargs) -> dict[str, Any]:
|
||||||
return {
|
return {
|
||||||
"api_key": api_key,
|
"api_key": api_key,
|
||||||
}
|
}
|
||||||
|
|
|
@ -45,6 +45,7 @@ class OpenAIInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
|
||||||
LiteLLMOpenAIMixin.__init__(
|
LiteLLMOpenAIMixin.__init__(
|
||||||
self,
|
self,
|
||||||
MODEL_ENTRIES,
|
MODEL_ENTRIES,
|
||||||
|
litellm_provider_name="openai",
|
||||||
api_key_from_config=config.api_key,
|
api_key_from_config=config.api_key,
|
||||||
provider_data_api_key_field="openai_api_key",
|
provider_data_api_key_field="openai_api_key",
|
||||||
)
|
)
|
||||||
|
|
|
@ -30,7 +30,7 @@ class SambaNovaImplConfig(BaseModel):
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def sample_run_config(cls, api_key: str = "${env.SAMBANOVA_API_KEY}", **kwargs) -> dict[str, Any]:
|
def sample_run_config(cls, api_key: str = "${env.SAMBANOVA_API_KEY:=}", **kwargs) -> dict[str, Any]:
|
||||||
return {
|
return {
|
||||||
"url": "https://api.sambanova.ai/v1",
|
"url": "https://api.sambanova.ai/v1",
|
||||||
"api_key": api_key,
|
"api_key": api_key,
|
||||||
|
|
|
@ -9,49 +9,20 @@ from llama_stack.providers.utils.inference.model_registry import (
|
||||||
build_hf_repo_model_entry,
|
build_hf_repo_model_entry,
|
||||||
)
|
)
|
||||||
|
|
||||||
SAFETY_MODELS_ENTRIES = [
|
SAFETY_MODELS_ENTRIES = []
|
||||||
build_hf_repo_model_entry(
|
|
||||||
"sambanova/Meta-Llama-Guard-3-8B",
|
|
||||||
CoreModelId.llama_guard_3_8b.value,
|
|
||||||
),
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
MODEL_ENTRIES = [
|
MODEL_ENTRIES = [
|
||||||
build_hf_repo_model_entry(
|
build_hf_repo_model_entry(
|
||||||
"sambanova/Meta-Llama-3.1-8B-Instruct",
|
"Meta-Llama-3.1-8B-Instruct",
|
||||||
CoreModelId.llama3_1_8b_instruct.value,
|
CoreModelId.llama3_1_8b_instruct.value,
|
||||||
),
|
),
|
||||||
build_hf_repo_model_entry(
|
build_hf_repo_model_entry(
|
||||||
"sambanova/Meta-Llama-3.1-405B-Instruct",
|
"Meta-Llama-3.3-70B-Instruct",
|
||||||
CoreModelId.llama3_1_405b_instruct.value,
|
|
||||||
),
|
|
||||||
build_hf_repo_model_entry(
|
|
||||||
"sambanova/Meta-Llama-3.2-1B-Instruct",
|
|
||||||
CoreModelId.llama3_2_1b_instruct.value,
|
|
||||||
),
|
|
||||||
build_hf_repo_model_entry(
|
|
||||||
"sambanova/Meta-Llama-3.2-3B-Instruct",
|
|
||||||
CoreModelId.llama3_2_3b_instruct.value,
|
|
||||||
),
|
|
||||||
build_hf_repo_model_entry(
|
|
||||||
"sambanova/Meta-Llama-3.3-70B-Instruct",
|
|
||||||
CoreModelId.llama3_3_70b_instruct.value,
|
CoreModelId.llama3_3_70b_instruct.value,
|
||||||
),
|
),
|
||||||
build_hf_repo_model_entry(
|
build_hf_repo_model_entry(
|
||||||
"sambanova/Llama-3.2-11B-Vision-Instruct",
|
"Llama-4-Maverick-17B-128E-Instruct",
|
||||||
CoreModelId.llama3_2_11b_vision_instruct.value,
|
|
||||||
),
|
|
||||||
build_hf_repo_model_entry(
|
|
||||||
"sambanova/Llama-3.2-90B-Vision-Instruct",
|
|
||||||
CoreModelId.llama3_2_90b_vision_instruct.value,
|
|
||||||
),
|
|
||||||
build_hf_repo_model_entry(
|
|
||||||
"sambanova/Llama-4-Scout-17B-16E-Instruct",
|
|
||||||
CoreModelId.llama4_scout_17b_16e_instruct.value,
|
|
||||||
),
|
|
||||||
build_hf_repo_model_entry(
|
|
||||||
"sambanova/Llama-4-Maverick-17B-128E-Instruct",
|
|
||||||
CoreModelId.llama4_maverick_17b_128e_instruct.value,
|
CoreModelId.llama4_maverick_17b_128e_instruct.value,
|
||||||
),
|
),
|
||||||
] + SAFETY_MODELS_ENTRIES
|
] + SAFETY_MODELS_ENTRIES
|
||||||
|
|
|
@ -182,6 +182,7 @@ class SambaNovaInferenceAdapter(LiteLLMOpenAIMixin):
|
||||||
LiteLLMOpenAIMixin.__init__(
|
LiteLLMOpenAIMixin.__init__(
|
||||||
self,
|
self,
|
||||||
model_entries=MODEL_ENTRIES,
|
model_entries=MODEL_ENTRIES,
|
||||||
|
litellm_provider_name="sambanova",
|
||||||
api_key_from_config=self.config.api_key.get_secret_value() if self.config.api_key else None,
|
api_key_from_config=self.config.api_key.get_secret_value() if self.config.api_key else None,
|
||||||
provider_data_api_key_field="sambanova_api_key",
|
provider_data_api_key_field="sambanova_api_key",
|
||||||
)
|
)
|
||||||
|
|
|
@ -19,7 +19,7 @@ class TGIImplConfig(BaseModel):
|
||||||
@classmethod
|
@classmethod
|
||||||
def sample_run_config(
|
def sample_run_config(
|
||||||
cls,
|
cls,
|
||||||
url: str = "${env.TGI_URL}",
|
url: str = "${env.TGI_URL:=}",
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
return {
|
return {
|
||||||
|
|
|
@ -305,6 +305,8 @@ class _HfAdapter(
|
||||||
|
|
||||||
class TGIAdapter(_HfAdapter):
|
class TGIAdapter(_HfAdapter):
|
||||||
async def initialize(self, config: TGIImplConfig) -> None:
|
async def initialize(self, config: TGIImplConfig) -> None:
|
||||||
|
if not config.url:
|
||||||
|
raise ValueError("You must provide a URL in run.yaml (or via the TGI_URL environment variable) to use TGI.")
|
||||||
log.info(f"Initializing TGI client with url={config.url}")
|
log.info(f"Initializing TGI client with url={config.url}")
|
||||||
self.client = AsyncInferenceClient(
|
self.client = AsyncInferenceClient(
|
||||||
model=config.url,
|
model=config.url,
|
||||||
|
|
|
@ -27,5 +27,5 @@ class TogetherImplConfig(RemoteInferenceProviderConfig):
|
||||||
def sample_run_config(cls, **kwargs) -> dict[str, Any]:
|
def sample_run_config(cls, **kwargs) -> dict[str, Any]:
|
||||||
return {
|
return {
|
||||||
"url": "https://api.together.xyz/v1",
|
"url": "https://api.together.xyz/v1",
|
||||||
"api_key": "${env.TOGETHER_API_KEY}",
|
"api_key": "${env.TOGETHER_API_KEY:=}",
|
||||||
}
|
}
|
||||||
|
|
|
@ -69,15 +69,9 @@ MODEL_ENTRIES = [
|
||||||
build_hf_repo_model_entry(
|
build_hf_repo_model_entry(
|
||||||
"meta-llama/Llama-4-Scout-17B-16E-Instruct",
|
"meta-llama/Llama-4-Scout-17B-16E-Instruct",
|
||||||
CoreModelId.llama4_scout_17b_16e_instruct.value,
|
CoreModelId.llama4_scout_17b_16e_instruct.value,
|
||||||
additional_aliases=[
|
|
||||||
"together/meta-llama/Llama-4-Scout-17B-16E-Instruct",
|
|
||||||
],
|
|
||||||
),
|
),
|
||||||
build_hf_repo_model_entry(
|
build_hf_repo_model_entry(
|
||||||
"meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
|
"meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
|
||||||
CoreModelId.llama4_maverick_17b_128e_instruct.value,
|
CoreModelId.llama4_maverick_17b_128e_instruct.value,
|
||||||
additional_aliases=[
|
|
||||||
"together/meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
|
|
||||||
],
|
|
||||||
),
|
),
|
||||||
] + SAFETY_MODELS_ENTRIES
|
] + SAFETY_MODELS_ENTRIES
|
||||||
|
|
|
@ -299,7 +299,10 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
||||||
self.client = None
|
self.client = None
|
||||||
|
|
||||||
async def initialize(self) -> None:
|
async def initialize(self) -> None:
|
||||||
pass
|
if not self.config.url:
|
||||||
|
raise ValueError(
|
||||||
|
"You must provide a URL in run.yaml (or via the VLLM_URL environment variable) to use vLLM."
|
||||||
|
)
|
||||||
|
|
||||||
async def should_refresh_models(self) -> bool:
|
async def should_refresh_models(self) -> bool:
|
||||||
return self.config.refresh_models
|
return self.config.refresh_models
|
||||||
|
@ -337,9 +340,6 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
||||||
HealthResponse: A dictionary containing the health status.
|
HealthResponse: A dictionary containing the health status.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
if not self.config.url:
|
|
||||||
return HealthResponse(status=HealthStatus.ERROR, message="vLLM URL is not set")
|
|
||||||
|
|
||||||
client = self._create_client() if self.client is None else self.client
|
client = self._create_client() if self.client is None else self.client
|
||||||
_ = [m async for m in client.models.list()] # Ensure the client is initialized
|
_ = [m async for m in client.models.list()] # Ensure the client is initialized
|
||||||
return HealthResponse(status=HealthStatus.OK)
|
return HealthResponse(status=HealthStatus.OK)
|
||||||
|
@ -355,11 +355,6 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
||||||
if self.client is not None:
|
if self.client is not None:
|
||||||
return
|
return
|
||||||
|
|
||||||
if not self.config.url:
|
|
||||||
raise ValueError(
|
|
||||||
"You must provide a vLLM URL in the run.yaml file (or set the VLLM_URL environment variable)"
|
|
||||||
)
|
|
||||||
|
|
||||||
log.info(f"Initializing vLLM client with base_url={self.config.url}")
|
log.info(f"Initializing vLLM client with base_url={self.config.url}")
|
||||||
self.client = self._create_client()
|
self.client = self._create_client()
|
||||||
|
|
||||||
|
|
|
@ -30,7 +30,7 @@ class SambaNovaSafetyConfig(BaseModel):
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def sample_run_config(cls, api_key: str = "${env.SAMBANOVA_API_KEY}", **kwargs) -> dict[str, Any]:
|
def sample_run_config(cls, api_key: str = "${env.SAMBANOVA_API_KEY:=}", **kwargs) -> dict[str, Any]:
|
||||||
return {
|
return {
|
||||||
"url": "https://api.sambanova.ai/v1",
|
"url": "https://api.sambanova.ai/v1",
|
||||||
"api_key": api_key,
|
"api_key": api_key,
|
||||||
|
|
|
@ -68,11 +68,14 @@ class LiteLLMOpenAIMixin(
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model_entries,
|
model_entries,
|
||||||
|
litellm_provider_name: str,
|
||||||
api_key_from_config: str | None,
|
api_key_from_config: str | None,
|
||||||
provider_data_api_key_field: str,
|
provider_data_api_key_field: str,
|
||||||
openai_compat_api_base: str | None = None,
|
openai_compat_api_base: str | None = None,
|
||||||
):
|
):
|
||||||
ModelRegistryHelper.__init__(self, model_entries)
|
ModelRegistryHelper.__init__(self, model_entries)
|
||||||
|
|
||||||
|
self.litellm_provider_name = litellm_provider_name
|
||||||
self.api_key_from_config = api_key_from_config
|
self.api_key_from_config = api_key_from_config
|
||||||
self.provider_data_api_key_field = provider_data_api_key_field
|
self.provider_data_api_key_field = provider_data_api_key_field
|
||||||
self.api_base = openai_compat_api_base
|
self.api_base = openai_compat_api_base
|
||||||
|
@ -91,7 +94,11 @@ class LiteLLMOpenAIMixin(
|
||||||
def get_litellm_model_name(self, model_id: str) -> str:
|
def get_litellm_model_name(self, model_id: str) -> str:
|
||||||
# users may be using openai/ prefix in their model names. the openai/models.py did this by default.
|
# users may be using openai/ prefix in their model names. the openai/models.py did this by default.
|
||||||
# model_id.startswith("openai/") is for backwards compatibility.
|
# model_id.startswith("openai/") is for backwards compatibility.
|
||||||
return "openai/" + model_id if self.is_openai_compat and not model_id.startswith("openai/") else model_id
|
return (
|
||||||
|
f"{self.litellm_provider_name}/{model_id}"
|
||||||
|
if self.is_openai_compat and not model_id.startswith(self.litellm_provider_name)
|
||||||
|
else model_id
|
||||||
|
)
|
||||||
|
|
||||||
async def completion(
|
async def completion(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -50,7 +50,8 @@ def build_hf_repo_model_entry(
|
||||||
additional_aliases: list[str] | None = None,
|
additional_aliases: list[str] | None = None,
|
||||||
) -> ProviderModelEntry:
|
) -> ProviderModelEntry:
|
||||||
aliases = [
|
aliases = [
|
||||||
get_huggingface_repo(model_descriptor),
|
# NOTE: avoid HF aliases because they _cannot_ be unique across providers
|
||||||
|
# get_huggingface_repo(model_descriptor),
|
||||||
]
|
]
|
||||||
if additional_aliases:
|
if additional_aliases:
|
||||||
aliases.extend(additional_aliases)
|
aliases.extend(additional_aliases)
|
||||||
|
@ -75,7 +76,9 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
|
||||||
__provider_id__: str
|
__provider_id__: str
|
||||||
|
|
||||||
def __init__(self, model_entries: list[ProviderModelEntry], allowed_models: list[str] | None = None):
|
def __init__(self, model_entries: list[ProviderModelEntry], allowed_models: list[str] | None = None):
|
||||||
|
self.model_entries = model_entries
|
||||||
self.allowed_models = allowed_models
|
self.allowed_models = allowed_models
|
||||||
|
|
||||||
self.alias_to_provider_id_map = {}
|
self.alias_to_provider_id_map = {}
|
||||||
self.provider_id_to_llama_model_map = {}
|
self.provider_id_to_llama_model_map = {}
|
||||||
for entry in model_entries:
|
for entry in model_entries:
|
||||||
|
@ -98,7 +101,7 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
|
||||||
continue
|
continue
|
||||||
models.append(
|
models.append(
|
||||||
Model(
|
Model(
|
||||||
model_id=id,
|
identifier=id,
|
||||||
provider_resource_id=entry.provider_model_id,
|
provider_resource_id=entry.provider_model_id,
|
||||||
model_type=ModelType.llm,
|
model_type=ModelType.llm,
|
||||||
metadata=entry.metadata,
|
metadata=entry.metadata,
|
||||||
|
@ -185,8 +188,8 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
|
||||||
return model
|
return model
|
||||||
|
|
||||||
async def unregister_model(self, model_id: str) -> None:
|
async def unregister_model(self, model_id: str) -> None:
|
||||||
# TODO: should we block unregistering base supported provider model IDs?
|
# model_id is the identifier, not the provider_resource_id
|
||||||
if model_id not in self.alias_to_provider_id_map:
|
# unfortunately, this ID can be of the form provider_id/model_id which
|
||||||
raise ValueError(f"Model id '{model_id}' is not registered.")
|
# we never registered. TODO: fix this by significantly rewriting
|
||||||
|
# registration and registry helper
|
||||||
del self.alias_to_provider_id_map[model_id]
|
pass
|
||||||
|
|
|
@ -7,21 +7,15 @@ distribution_spec:
|
||||||
- provider_type: remote::ollama
|
- provider_type: remote::ollama
|
||||||
- provider_type: remote::vllm
|
- provider_type: remote::vllm
|
||||||
- provider_type: remote::tgi
|
- provider_type: remote::tgi
|
||||||
- provider_type: remote::hf::serverless
|
|
||||||
- provider_type: remote::hf::endpoint
|
|
||||||
- provider_type: remote::fireworks
|
- provider_type: remote::fireworks
|
||||||
- provider_type: remote::together
|
- provider_type: remote::together
|
||||||
- provider_type: remote::bedrock
|
- provider_type: remote::bedrock
|
||||||
- provider_type: remote::databricks
|
|
||||||
- provider_type: remote::nvidia
|
- provider_type: remote::nvidia
|
||||||
- provider_type: remote::runpod
|
|
||||||
- provider_type: remote::openai
|
- provider_type: remote::openai
|
||||||
- provider_type: remote::anthropic
|
- provider_type: remote::anthropic
|
||||||
- provider_type: remote::gemini
|
- provider_type: remote::gemini
|
||||||
- provider_type: remote::groq
|
- provider_type: remote::groq
|
||||||
- provider_type: remote::llama-openai-compat
|
|
||||||
- provider_type: remote::sambanova
|
- provider_type: remote::sambanova
|
||||||
- provider_type: remote::passthrough
|
|
||||||
- provider_type: inline::sentence-transformers
|
- provider_type: inline::sentence-transformers
|
||||||
vector_io:
|
vector_io:
|
||||||
- provider_type: inline::faiss
|
- provider_type: inline::faiss
|
||||||
|
|
File diff suppressed because it is too large
Load diff
|
@ -89,101 +89,51 @@ models:
|
||||||
provider_id: nvidia
|
provider_id: nvidia
|
||||||
provider_model_id: meta/llama3-8b-instruct
|
provider_model_id: meta/llama3-8b-instruct
|
||||||
model_type: llm
|
model_type: llm
|
||||||
- metadata: {}
|
|
||||||
model_id: meta-llama/Llama-3-8B-Instruct
|
|
||||||
provider_id: nvidia
|
|
||||||
provider_model_id: meta/llama3-8b-instruct
|
|
||||||
model_type: llm
|
|
||||||
- metadata: {}
|
- metadata: {}
|
||||||
model_id: meta/llama3-70b-instruct
|
model_id: meta/llama3-70b-instruct
|
||||||
provider_id: nvidia
|
provider_id: nvidia
|
||||||
provider_model_id: meta/llama3-70b-instruct
|
provider_model_id: meta/llama3-70b-instruct
|
||||||
model_type: llm
|
model_type: llm
|
||||||
- metadata: {}
|
|
||||||
model_id: meta-llama/Llama-3-70B-Instruct
|
|
||||||
provider_id: nvidia
|
|
||||||
provider_model_id: meta/llama3-70b-instruct
|
|
||||||
model_type: llm
|
|
||||||
- metadata: {}
|
- metadata: {}
|
||||||
model_id: meta/llama-3.1-8b-instruct
|
model_id: meta/llama-3.1-8b-instruct
|
||||||
provider_id: nvidia
|
provider_id: nvidia
|
||||||
provider_model_id: meta/llama-3.1-8b-instruct
|
provider_model_id: meta/llama-3.1-8b-instruct
|
||||||
model_type: llm
|
model_type: llm
|
||||||
- metadata: {}
|
|
||||||
model_id: meta-llama/Llama-3.1-8B-Instruct
|
|
||||||
provider_id: nvidia
|
|
||||||
provider_model_id: meta/llama-3.1-8b-instruct
|
|
||||||
model_type: llm
|
|
||||||
- metadata: {}
|
- metadata: {}
|
||||||
model_id: meta/llama-3.1-70b-instruct
|
model_id: meta/llama-3.1-70b-instruct
|
||||||
provider_id: nvidia
|
provider_id: nvidia
|
||||||
provider_model_id: meta/llama-3.1-70b-instruct
|
provider_model_id: meta/llama-3.1-70b-instruct
|
||||||
model_type: llm
|
model_type: llm
|
||||||
- metadata: {}
|
|
||||||
model_id: meta-llama/Llama-3.1-70B-Instruct
|
|
||||||
provider_id: nvidia
|
|
||||||
provider_model_id: meta/llama-3.1-70b-instruct
|
|
||||||
model_type: llm
|
|
||||||
- metadata: {}
|
- metadata: {}
|
||||||
model_id: meta/llama-3.1-405b-instruct
|
model_id: meta/llama-3.1-405b-instruct
|
||||||
provider_id: nvidia
|
provider_id: nvidia
|
||||||
provider_model_id: meta/llama-3.1-405b-instruct
|
provider_model_id: meta/llama-3.1-405b-instruct
|
||||||
model_type: llm
|
model_type: llm
|
||||||
- metadata: {}
|
|
||||||
model_id: meta-llama/Llama-3.1-405B-Instruct-FP8
|
|
||||||
provider_id: nvidia
|
|
||||||
provider_model_id: meta/llama-3.1-405b-instruct
|
|
||||||
model_type: llm
|
|
||||||
- metadata: {}
|
- metadata: {}
|
||||||
model_id: meta/llama-3.2-1b-instruct
|
model_id: meta/llama-3.2-1b-instruct
|
||||||
provider_id: nvidia
|
provider_id: nvidia
|
||||||
provider_model_id: meta/llama-3.2-1b-instruct
|
provider_model_id: meta/llama-3.2-1b-instruct
|
||||||
model_type: llm
|
model_type: llm
|
||||||
- metadata: {}
|
|
||||||
model_id: meta-llama/Llama-3.2-1B-Instruct
|
|
||||||
provider_id: nvidia
|
|
||||||
provider_model_id: meta/llama-3.2-1b-instruct
|
|
||||||
model_type: llm
|
|
||||||
- metadata: {}
|
- metadata: {}
|
||||||
model_id: meta/llama-3.2-3b-instruct
|
model_id: meta/llama-3.2-3b-instruct
|
||||||
provider_id: nvidia
|
provider_id: nvidia
|
||||||
provider_model_id: meta/llama-3.2-3b-instruct
|
provider_model_id: meta/llama-3.2-3b-instruct
|
||||||
model_type: llm
|
model_type: llm
|
||||||
- metadata: {}
|
|
||||||
model_id: meta-llama/Llama-3.2-3B-Instruct
|
|
||||||
provider_id: nvidia
|
|
||||||
provider_model_id: meta/llama-3.2-3b-instruct
|
|
||||||
model_type: llm
|
|
||||||
- metadata: {}
|
- metadata: {}
|
||||||
model_id: meta/llama-3.2-11b-vision-instruct
|
model_id: meta/llama-3.2-11b-vision-instruct
|
||||||
provider_id: nvidia
|
provider_id: nvidia
|
||||||
provider_model_id: meta/llama-3.2-11b-vision-instruct
|
provider_model_id: meta/llama-3.2-11b-vision-instruct
|
||||||
model_type: llm
|
model_type: llm
|
||||||
- metadata: {}
|
|
||||||
model_id: meta-llama/Llama-3.2-11B-Vision-Instruct
|
|
||||||
provider_id: nvidia
|
|
||||||
provider_model_id: meta/llama-3.2-11b-vision-instruct
|
|
||||||
model_type: llm
|
|
||||||
- metadata: {}
|
- metadata: {}
|
||||||
model_id: meta/llama-3.2-90b-vision-instruct
|
model_id: meta/llama-3.2-90b-vision-instruct
|
||||||
provider_id: nvidia
|
provider_id: nvidia
|
||||||
provider_model_id: meta/llama-3.2-90b-vision-instruct
|
provider_model_id: meta/llama-3.2-90b-vision-instruct
|
||||||
model_type: llm
|
model_type: llm
|
||||||
- metadata: {}
|
|
||||||
model_id: meta-llama/Llama-3.2-90B-Vision-Instruct
|
|
||||||
provider_id: nvidia
|
|
||||||
provider_model_id: meta/llama-3.2-90b-vision-instruct
|
|
||||||
model_type: llm
|
|
||||||
- metadata: {}
|
- metadata: {}
|
||||||
model_id: meta/llama-3.3-70b-instruct
|
model_id: meta/llama-3.3-70b-instruct
|
||||||
provider_id: nvidia
|
provider_id: nvidia
|
||||||
provider_model_id: meta/llama-3.3-70b-instruct
|
provider_model_id: meta/llama-3.3-70b-instruct
|
||||||
model_type: llm
|
model_type: llm
|
||||||
- metadata: {}
|
|
||||||
model_id: meta-llama/Llama-3.3-70B-Instruct
|
|
||||||
provider_id: nvidia
|
|
||||||
provider_model_id: meta/llama-3.3-70b-instruct
|
|
||||||
model_type: llm
|
|
||||||
- metadata:
|
- metadata:
|
||||||
embedding_dimension: 2048
|
embedding_dimension: 2048
|
||||||
context_length: 8192
|
context_length: 8192
|
||||||
|
|
|
@ -33,7 +33,7 @@ providers:
|
||||||
provider_type: remote::together
|
provider_type: remote::together
|
||||||
config:
|
config:
|
||||||
url: https://api.together.xyz/v1
|
url: https://api.together.xyz/v1
|
||||||
api_key: ${env.TOGETHER_API_KEY}
|
api_key: ${env.TOGETHER_API_KEY:=}
|
||||||
vector_io:
|
vector_io:
|
||||||
- provider_id: sqlite-vec
|
- provider_id: sqlite-vec
|
||||||
provider_type: inline::sqlite-vec
|
provider_type: inline::sqlite-vec
|
||||||
|
|
|
@ -7,21 +7,15 @@ distribution_spec:
|
||||||
- provider_type: remote::ollama
|
- provider_type: remote::ollama
|
||||||
- provider_type: remote::vllm
|
- provider_type: remote::vllm
|
||||||
- provider_type: remote::tgi
|
- provider_type: remote::tgi
|
||||||
- provider_type: remote::hf::serverless
|
|
||||||
- provider_type: remote::hf::endpoint
|
|
||||||
- provider_type: remote::fireworks
|
- provider_type: remote::fireworks
|
||||||
- provider_type: remote::together
|
- provider_type: remote::together
|
||||||
- provider_type: remote::bedrock
|
- provider_type: remote::bedrock
|
||||||
- provider_type: remote::databricks
|
|
||||||
- provider_type: remote::nvidia
|
- provider_type: remote::nvidia
|
||||||
- provider_type: remote::runpod
|
|
||||||
- provider_type: remote::openai
|
- provider_type: remote::openai
|
||||||
- provider_type: remote::anthropic
|
- provider_type: remote::anthropic
|
||||||
- provider_type: remote::gemini
|
- provider_type: remote::gemini
|
||||||
- provider_type: remote::groq
|
- provider_type: remote::groq
|
||||||
- provider_type: remote::llama-openai-compat
|
|
||||||
- provider_type: remote::sambanova
|
- provider_type: remote::sambanova
|
||||||
- provider_type: remote::passthrough
|
|
||||||
- provider_type: inline::sentence-transformers
|
- provider_type: inline::sentence-transformers
|
||||||
vector_io:
|
vector_io:
|
||||||
- provider_type: inline::faiss
|
- provider_type: inline::faiss
|
||||||
|
|
File diff suppressed because it is too large
Load diff
|
@ -7,20 +7,19 @@
|
||||||
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from llama_stack.apis.models import ModelType
|
|
||||||
from llama_stack.distribution.datatypes import (
|
from llama_stack.distribution.datatypes import (
|
||||||
BuildProvider,
|
BuildProvider,
|
||||||
ModelInput,
|
|
||||||
Provider,
|
Provider,
|
||||||
ProviderSpec,
|
ProviderSpec,
|
||||||
|
ShieldInput,
|
||||||
ToolGroupInput,
|
ToolGroupInput,
|
||||||
)
|
)
|
||||||
from llama_stack.distribution.utils.dynamic import instantiate_class_type
|
from llama_stack.distribution.utils.dynamic import instantiate_class_type
|
||||||
|
from llama_stack.providers.datatypes import RemoteProviderSpec
|
||||||
from llama_stack.providers.inline.files.localfs.config import LocalfsFilesImplConfig
|
from llama_stack.providers.inline.files.localfs.config import LocalfsFilesImplConfig
|
||||||
from llama_stack.providers.inline.inference.sentence_transformers import (
|
from llama_stack.providers.inline.inference.sentence_transformers import (
|
||||||
SentenceTransformersInferenceConfig,
|
SentenceTransformersInferenceConfig,
|
||||||
)
|
)
|
||||||
from llama_stack.providers.inline.post_training.huggingface import HuggingFacePostTrainingConfig
|
|
||||||
from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
|
from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
|
||||||
from llama_stack.providers.inline.vector_io.milvus.config import (
|
from llama_stack.providers.inline.vector_io.milvus.config import (
|
||||||
MilvusVectorIOConfig,
|
MilvusVectorIOConfig,
|
||||||
|
@ -29,117 +28,17 @@ from llama_stack.providers.inline.vector_io.sqlite_vec.config import (
|
||||||
SQLiteVectorIOConfig,
|
SQLiteVectorIOConfig,
|
||||||
)
|
)
|
||||||
from llama_stack.providers.registry.inference import available_providers
|
from llama_stack.providers.registry.inference import available_providers
|
||||||
from llama_stack.providers.remote.inference.anthropic.models import (
|
|
||||||
MODEL_ENTRIES as ANTHROPIC_MODEL_ENTRIES,
|
|
||||||
)
|
|
||||||
from llama_stack.providers.remote.inference.bedrock.models import (
|
|
||||||
MODEL_ENTRIES as BEDROCK_MODEL_ENTRIES,
|
|
||||||
)
|
|
||||||
from llama_stack.providers.remote.inference.cerebras.models import (
|
|
||||||
MODEL_ENTRIES as CEREBRAS_MODEL_ENTRIES,
|
|
||||||
)
|
|
||||||
from llama_stack.providers.remote.inference.databricks.databricks import (
|
|
||||||
MODEL_ENTRIES as DATABRICKS_MODEL_ENTRIES,
|
|
||||||
)
|
|
||||||
from llama_stack.providers.remote.inference.fireworks.models import (
|
|
||||||
MODEL_ENTRIES as FIREWORKS_MODEL_ENTRIES,
|
|
||||||
)
|
|
||||||
from llama_stack.providers.remote.inference.gemini.models import (
|
|
||||||
MODEL_ENTRIES as GEMINI_MODEL_ENTRIES,
|
|
||||||
)
|
|
||||||
from llama_stack.providers.remote.inference.groq.models import (
|
|
||||||
MODEL_ENTRIES as GROQ_MODEL_ENTRIES,
|
|
||||||
)
|
|
||||||
from llama_stack.providers.remote.inference.nvidia.models import (
|
|
||||||
MODEL_ENTRIES as NVIDIA_MODEL_ENTRIES,
|
|
||||||
)
|
|
||||||
from llama_stack.providers.remote.inference.openai.models import (
|
|
||||||
MODEL_ENTRIES as OPENAI_MODEL_ENTRIES,
|
|
||||||
)
|
|
||||||
from llama_stack.providers.remote.inference.runpod.runpod import (
|
|
||||||
MODEL_ENTRIES as RUNPOD_MODEL_ENTRIES,
|
|
||||||
)
|
|
||||||
from llama_stack.providers.remote.inference.sambanova.models import (
|
|
||||||
MODEL_ENTRIES as SAMBANOVA_MODEL_ENTRIES,
|
|
||||||
)
|
|
||||||
from llama_stack.providers.remote.inference.together.models import (
|
|
||||||
MODEL_ENTRIES as TOGETHER_MODEL_ENTRIES,
|
|
||||||
)
|
|
||||||
from llama_stack.providers.remote.vector_io.chroma.config import ChromaVectorIOConfig
|
from llama_stack.providers.remote.vector_io.chroma.config import ChromaVectorIOConfig
|
||||||
from llama_stack.providers.remote.vector_io.pgvector.config import (
|
from llama_stack.providers.remote.vector_io.pgvector.config import (
|
||||||
PGVectorVectorIOConfig,
|
PGVectorVectorIOConfig,
|
||||||
)
|
)
|
||||||
from llama_stack.providers.utils.inference.model_registry import ProviderModelEntry
|
|
||||||
from llama_stack.providers.utils.sqlstore.sqlstore import PostgresSqlStoreConfig
|
from llama_stack.providers.utils.sqlstore.sqlstore import PostgresSqlStoreConfig
|
||||||
from llama_stack.templates.template import (
|
from llama_stack.templates.template import (
|
||||||
DistributionTemplate,
|
DistributionTemplate,
|
||||||
RunConfigSettings,
|
RunConfigSettings,
|
||||||
get_model_registry,
|
|
||||||
get_shield_registry,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _get_model_entries_for_provider(provider_type: str) -> list[ProviderModelEntry]:
|
|
||||||
"""Get model entries for a specific provider type."""
|
|
||||||
model_entries_map = {
|
|
||||||
"openai": OPENAI_MODEL_ENTRIES,
|
|
||||||
"fireworks": FIREWORKS_MODEL_ENTRIES,
|
|
||||||
"together": TOGETHER_MODEL_ENTRIES,
|
|
||||||
"anthropic": ANTHROPIC_MODEL_ENTRIES,
|
|
||||||
"gemini": GEMINI_MODEL_ENTRIES,
|
|
||||||
"groq": GROQ_MODEL_ENTRIES,
|
|
||||||
"sambanova": SAMBANOVA_MODEL_ENTRIES,
|
|
||||||
"cerebras": CEREBRAS_MODEL_ENTRIES,
|
|
||||||
"bedrock": BEDROCK_MODEL_ENTRIES,
|
|
||||||
"databricks": DATABRICKS_MODEL_ENTRIES,
|
|
||||||
"nvidia": NVIDIA_MODEL_ENTRIES,
|
|
||||||
"runpod": RUNPOD_MODEL_ENTRIES,
|
|
||||||
}
|
|
||||||
|
|
||||||
# Special handling for providers with dynamic model entries
|
|
||||||
if provider_type == "ollama":
|
|
||||||
return [
|
|
||||||
ProviderModelEntry(
|
|
||||||
provider_model_id="${env.OLLAMA_INFERENCE_MODEL:=__disabled__}",
|
|
||||||
model_type=ModelType.llm,
|
|
||||||
),
|
|
||||||
ProviderModelEntry(
|
|
||||||
provider_model_id="${env.SAFETY_MODEL:=__disabled__}",
|
|
||||||
model_type=ModelType.llm,
|
|
||||||
),
|
|
||||||
ProviderModelEntry(
|
|
||||||
provider_model_id="${env.OLLAMA_EMBEDDING_MODEL:=__disabled__}",
|
|
||||||
model_type=ModelType.embedding,
|
|
||||||
metadata={
|
|
||||||
"embedding_dimension": "${env.OLLAMA_EMBEDDING_DIMENSION:=384}",
|
|
||||||
},
|
|
||||||
),
|
|
||||||
]
|
|
||||||
elif provider_type == "vllm":
|
|
||||||
return [
|
|
||||||
ProviderModelEntry(
|
|
||||||
provider_model_id="${env.VLLM_INFERENCE_MODEL:=__disabled__}",
|
|
||||||
model_type=ModelType.llm,
|
|
||||||
),
|
|
||||||
]
|
|
||||||
|
|
||||||
return model_entries_map.get(provider_type, [])
|
|
||||||
|
|
||||||
|
|
||||||
def _get_model_safety_entries_for_provider(provider_type: str) -> list[ProviderModelEntry]:
|
|
||||||
"""Get model entries for a specific provider type."""
|
|
||||||
safety_model_entries_map = {
|
|
||||||
"ollama": [
|
|
||||||
ProviderModelEntry(
|
|
||||||
provider_model_id="${env.SAFETY_MODEL:=__disabled__}",
|
|
||||||
model_type=ModelType.llm,
|
|
||||||
),
|
|
||||||
],
|
|
||||||
}
|
|
||||||
|
|
||||||
return safety_model_entries_map.get(provider_type, [])
|
|
||||||
|
|
||||||
|
|
||||||
def _get_config_for_provider(provider_spec: ProviderSpec) -> dict[str, Any]:
|
def _get_config_for_provider(provider_spec: ProviderSpec) -> dict[str, Any]:
|
||||||
"""Get configuration for a provider using its adapter's config class."""
|
"""Get configuration for a provider using its adapter's config class."""
|
||||||
config_class = instantiate_class_type(provider_spec.config_class)
|
config_class = instantiate_class_type(provider_spec.config_class)
|
||||||
|
@ -150,40 +49,48 @@ def _get_config_for_provider(provider_spec: ProviderSpec) -> dict[str, Any]:
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
|
||||||
def get_remote_inference_providers() -> tuple[list[Provider], dict[str, list[ProviderModelEntry]]]:
|
ENABLED_INFERENCE_PROVIDERS = [
|
||||||
all_providers = available_providers()
|
"ollama",
|
||||||
|
"vllm",
|
||||||
|
"tgi",
|
||||||
|
"fireworks",
|
||||||
|
"together",
|
||||||
|
"gemini",
|
||||||
|
"groq",
|
||||||
|
"sambanova",
|
||||||
|
"anthropic",
|
||||||
|
"openai",
|
||||||
|
"cerebras",
|
||||||
|
"nvidia",
|
||||||
|
"bedrock",
|
||||||
|
]
|
||||||
|
|
||||||
# Filter out inline providers and watsonx - the starter distro only exposes remote providers
|
INFERENCE_PROVIDER_IDS = {
|
||||||
|
"vllm": "${env.VLLM_URL:+vllm}",
|
||||||
|
"tgi": "${env.TGI_URL:+tgi}",
|
||||||
|
"cerebras": "${env.CEREBRAS_API_KEY:+cerebras}",
|
||||||
|
"nvidia": "${env.NVIDIA_API_KEY:+nvidia}",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_remote_inference_providers() -> list[Provider]:
|
||||||
|
# Filter out inline providers and some others - the starter distro only exposes remote providers
|
||||||
remote_providers = [
|
remote_providers = [
|
||||||
provider
|
provider
|
||||||
for provider in all_providers
|
for provider in available_providers()
|
||||||
# TODO: re-add once the Python 3.13 issue is fixed
|
if isinstance(provider, RemoteProviderSpec) and provider.adapter.adapter_type in ENABLED_INFERENCE_PROVIDERS
|
||||||
# discussion: https://github.com/meta-llama/llama-stack/pull/2327#discussion_r2156883828
|
|
||||||
if hasattr(provider, "adapter") and provider.adapter.adapter_type != "watsonx"
|
|
||||||
]
|
]
|
||||||
|
|
||||||
providers = []
|
inference_providers = []
|
||||||
available_models = {}
|
|
||||||
|
|
||||||
for provider_spec in remote_providers:
|
for provider_spec in remote_providers:
|
||||||
provider_type = provider_spec.adapter.adapter_type
|
provider_type = provider_spec.adapter.adapter_type
|
||||||
|
|
||||||
# Build the environment variable name for enabling this provider
|
if provider_type in INFERENCE_PROVIDER_IDS:
|
||||||
env_var = f"ENABLE_{provider_type.upper().replace('-', '_').replace('::', '_')}"
|
provider_id = INFERENCE_PROVIDER_IDS[provider_type]
|
||||||
model_entries = _get_model_entries_for_provider(provider_type)
|
else:
|
||||||
|
provider_id = provider_type.replace("-", "_").replace("::", "_")
|
||||||
config = _get_config_for_provider(provider_spec)
|
config = _get_config_for_provider(provider_spec)
|
||||||
providers.append(
|
|
||||||
(
|
|
||||||
f"${{env.{env_var}:=__disabled__}}",
|
|
||||||
provider_type,
|
|
||||||
model_entries,
|
|
||||||
config,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
available_models[f"${{env.{env_var}:=__disabled__}}"] = model_entries
|
|
||||||
|
|
||||||
inference_providers = []
|
|
||||||
for provider_id, provider_type, model_entries, config in providers:
|
|
||||||
inference_providers.append(
|
inference_providers.append(
|
||||||
Provider(
|
Provider(
|
||||||
provider_id=provider_id,
|
provider_id=provider_id,
|
||||||
|
@ -191,31 +98,13 @@ def get_remote_inference_providers() -> tuple[list[Provider], dict[str, list[Pro
|
||||||
config=config,
|
config=config,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
available_models[provider_id] = model_entries
|
return inference_providers
|
||||||
return inference_providers, available_models
|
|
||||||
|
|
||||||
|
|
||||||
# build a list of shields for all possible providers
|
|
||||||
def get_safety_models_for_providers(providers: list[Provider]) -> dict[str, list[ProviderModelEntry]]:
|
|
||||||
available_models = {}
|
|
||||||
for provider in providers:
|
|
||||||
provider_type = provider.provider_type.split("::")[1]
|
|
||||||
safety_model_entries = _get_model_safety_entries_for_provider(provider_type)
|
|
||||||
if len(safety_model_entries) == 0:
|
|
||||||
continue
|
|
||||||
|
|
||||||
env_var = f"ENABLE_{provider_type.upper().replace('-', '_').replace('::', '_')}"
|
|
||||||
provider_id = f"${{env.{env_var}:=__disabled__}}"
|
|
||||||
|
|
||||||
available_models[provider_id] = safety_model_entries
|
|
||||||
|
|
||||||
return available_models
|
|
||||||
|
|
||||||
|
|
||||||
def get_distribution_template() -> DistributionTemplate:
|
def get_distribution_template() -> DistributionTemplate:
|
||||||
remote_inference_providers, available_models = get_remote_inference_providers()
|
remote_inference_providers = get_remote_inference_providers()
|
||||||
name = "starter"
|
name = "starter"
|
||||||
# For build config, use BuildProvider with only provider_type and module
|
|
||||||
providers = {
|
providers = {
|
||||||
"inference": [BuildProvider(provider_type=p.provider_type, module=p.module) for p in remote_inference_providers]
|
"inference": [BuildProvider(provider_type=p.provider_type, module=p.module) for p in remote_inference_providers]
|
||||||
+ [BuildProvider(provider_type="inline::sentence-transformers")],
|
+ [BuildProvider(provider_type="inline::sentence-transformers")],
|
||||||
|
@ -254,15 +143,10 @@ def get_distribution_template() -> DistributionTemplate:
|
||||||
config=LocalfsFilesImplConfig.sample_run_config(f"~/.llama/distributions/{name}"),
|
config=LocalfsFilesImplConfig.sample_run_config(f"~/.llama/distributions/{name}"),
|
||||||
)
|
)
|
||||||
embedding_provider = Provider(
|
embedding_provider = Provider(
|
||||||
provider_id="${env.ENABLE_SENTENCE_TRANSFORMERS:=sentence-transformers}",
|
provider_id="sentence-transformers",
|
||||||
provider_type="inline::sentence-transformers",
|
provider_type="inline::sentence-transformers",
|
||||||
config=SentenceTransformersInferenceConfig.sample_run_config(),
|
config=SentenceTransformersInferenceConfig.sample_run_config(),
|
||||||
)
|
)
|
||||||
post_training_provider = Provider(
|
|
||||||
provider_id="huggingface",
|
|
||||||
provider_type="inline::huggingface",
|
|
||||||
config=HuggingFacePostTrainingConfig.sample_run_config(f"~/.llama/distributions/{name}"),
|
|
||||||
)
|
|
||||||
default_tool_groups = [
|
default_tool_groups = [
|
||||||
ToolGroupInput(
|
ToolGroupInput(
|
||||||
toolgroup_id="builtin::websearch",
|
toolgroup_id="builtin::websearch",
|
||||||
|
@ -273,19 +157,14 @@ def get_distribution_template() -> DistributionTemplate:
|
||||||
provider_id="rag-runtime",
|
provider_id="rag-runtime",
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
embedding_model = ModelInput(
|
default_shields = [
|
||||||
model_id="all-MiniLM-L6-v2",
|
# if the
|
||||||
provider_id=embedding_provider.provider_id,
|
ShieldInput(
|
||||||
model_type=ModelType.embedding,
|
shield_id="llama-guard",
|
||||||
metadata={
|
provider_id="${env.SAFETY_MODEL:+llama-guard}",
|
||||||
"embedding_dimension": 384,
|
provider_shield_id="${env.SAFETY_MODEL:=}",
|
||||||
},
|
),
|
||||||
)
|
]
|
||||||
|
|
||||||
default_models, ids_conflict_in_models = get_model_registry(available_models)
|
|
||||||
|
|
||||||
available_safety_models = get_safety_models_for_providers(remote_inference_providers)
|
|
||||||
shields = get_shield_registry(available_safety_models, ids_conflict_in_models)
|
|
||||||
|
|
||||||
return DistributionTemplate(
|
return DistributionTemplate(
|
||||||
name=name,
|
name=name,
|
||||||
|
@ -294,7 +173,6 @@ def get_distribution_template() -> DistributionTemplate:
|
||||||
container_image=None,
|
container_image=None,
|
||||||
template_path=None,
|
template_path=None,
|
||||||
providers=providers,
|
providers=providers,
|
||||||
available_models_by_provider=available_models,
|
|
||||||
additional_pip_packages=PostgresSqlStoreConfig.pip_packages(),
|
additional_pip_packages=PostgresSqlStoreConfig.pip_packages(),
|
||||||
run_configs={
|
run_configs={
|
||||||
"run.yaml": RunConfigSettings(
|
"run.yaml": RunConfigSettings(
|
||||||
|
@ -302,22 +180,22 @@ def get_distribution_template() -> DistributionTemplate:
|
||||||
"inference": remote_inference_providers + [embedding_provider],
|
"inference": remote_inference_providers + [embedding_provider],
|
||||||
"vector_io": [
|
"vector_io": [
|
||||||
Provider(
|
Provider(
|
||||||
provider_id="${env.ENABLE_FAISS:=faiss}",
|
provider_id="faiss",
|
||||||
provider_type="inline::faiss",
|
provider_type="inline::faiss",
|
||||||
config=FaissVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
|
config=FaissVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
|
||||||
),
|
),
|
||||||
Provider(
|
Provider(
|
||||||
provider_id="${env.ENABLE_SQLITE_VEC:=__disabled__}",
|
provider_id="sqlite-vec",
|
||||||
provider_type="inline::sqlite-vec",
|
provider_type="inline::sqlite-vec",
|
||||||
config=SQLiteVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
|
config=SQLiteVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
|
||||||
),
|
),
|
||||||
Provider(
|
Provider(
|
||||||
provider_id="${env.ENABLE_MILVUS:=__disabled__}",
|
provider_id="${env.MILVUS_URL:+milvus}",
|
||||||
provider_type="inline::milvus",
|
provider_type="inline::milvus",
|
||||||
config=MilvusVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
|
config=MilvusVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"),
|
||||||
),
|
),
|
||||||
Provider(
|
Provider(
|
||||||
provider_id="${env.ENABLE_CHROMADB:=__disabled__}",
|
provider_id="${env.CHROMADB_URL:+chromadb}",
|
||||||
provider_type="remote::chromadb",
|
provider_type="remote::chromadb",
|
||||||
config=ChromaVectorIOConfig.sample_run_config(
|
config=ChromaVectorIOConfig.sample_run_config(
|
||||||
f"~/.llama/distributions/{name}/",
|
f"~/.llama/distributions/{name}/",
|
||||||
|
@ -325,7 +203,7 @@ def get_distribution_template() -> DistributionTemplate:
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
Provider(
|
Provider(
|
||||||
provider_id="${env.ENABLE_PGVECTOR:=__disabled__}",
|
provider_id="${env.PGVECTOR_DB:+pgvector}",
|
||||||
provider_type="remote::pgvector",
|
provider_type="remote::pgvector",
|
||||||
config=PGVectorVectorIOConfig.sample_run_config(
|
config=PGVectorVectorIOConfig.sample_run_config(
|
||||||
f"~/.llama/distributions/{name}",
|
f"~/.llama/distributions/{name}",
|
||||||
|
@ -336,12 +214,10 @@ def get_distribution_template() -> DistributionTemplate:
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
"files": [files_provider],
|
"files": [files_provider],
|
||||||
"post_training": [post_training_provider],
|
|
||||||
},
|
},
|
||||||
default_models=[embedding_model] + default_models,
|
default_models=[],
|
||||||
default_tool_groups=default_tool_groups,
|
default_tool_groups=default_tool_groups,
|
||||||
# TODO: add a way to enable/disable shields on the fly
|
default_shields=default_shields,
|
||||||
default_shields=shields,
|
|
||||||
),
|
),
|
||||||
},
|
},
|
||||||
run_config_env_vars={
|
run_config_env_vars={
|
||||||
|
@ -385,17 +261,5 @@ def get_distribution_template() -> DistributionTemplate:
|
||||||
"http://localhost:11434",
|
"http://localhost:11434",
|
||||||
"Ollama URL",
|
"Ollama URL",
|
||||||
),
|
),
|
||||||
"OLLAMA_INFERENCE_MODEL": (
|
|
||||||
"",
|
|
||||||
"Optional Ollama Inference Model to register on startup",
|
|
||||||
),
|
|
||||||
"OLLAMA_EMBEDDING_MODEL": (
|
|
||||||
"",
|
|
||||||
"Optional Ollama Embedding Model to register on startup",
|
|
||||||
),
|
|
||||||
"OLLAMA_EMBEDDING_DIMENSION": (
|
|
||||||
"384",
|
|
||||||
"Ollama Embedding Dimension",
|
|
||||||
),
|
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
|
@ -25,7 +25,7 @@ dependencies = [
|
||||||
"fastapi>=0.115.0,<1.0", # server
|
"fastapi>=0.115.0,<1.0", # server
|
||||||
"fire", # for MCP in LLS client
|
"fire", # for MCP in LLS client
|
||||||
"httpx",
|
"httpx",
|
||||||
"huggingface-hub>=0.30.0,<1.0",
|
"huggingface-hub>=0.34.0,<1.0",
|
||||||
"jinja2>=3.1.6",
|
"jinja2>=3.1.6",
|
||||||
"jsonschema",
|
"jsonschema",
|
||||||
"llama-stack-client>=0.2.15",
|
"llama-stack-client>=0.2.15",
|
||||||
|
|
|
@ -86,7 +86,7 @@ httpx==0.28.1
|
||||||
# llama-stack
|
# llama-stack
|
||||||
# llama-stack-client
|
# llama-stack-client
|
||||||
# openai
|
# openai
|
||||||
huggingface-hub==0.33.0
|
huggingface-hub==0.34.1
|
||||||
# via llama-stack
|
# via llama-stack
|
||||||
idna==3.10
|
idna==3.10
|
||||||
# via
|
# via
|
||||||
|
|
|
@ -222,9 +222,7 @@ cmd=( run -d "${PLATFORM_OPTS[@]}" --name llama-stack \
|
||||||
--network llama-net \
|
--network llama-net \
|
||||||
-p "${PORT}:${PORT}" \
|
-p "${PORT}:${PORT}" \
|
||||||
"${SERVER_IMAGE}" --port "${PORT}" \
|
"${SERVER_IMAGE}" --port "${PORT}" \
|
||||||
--env OLLAMA_INFERENCE_MODEL="${MODEL_ALIAS}" \
|
--env OLLAMA_URL="http://ollama-server:${OLLAMA_PORT}")
|
||||||
--env OLLAMA_URL="http://ollama-server:${OLLAMA_PORT}" \
|
|
||||||
--env ENABLE_OLLAMA=ollama)
|
|
||||||
|
|
||||||
log "🦙 Starting Llama Stack..."
|
log "🦙 Starting Llama Stack..."
|
||||||
if ! execute_with_log $ENGINE "${cmd[@]}"; then
|
if ! execute_with_log $ENGINE "${cmd[@]}"; then
|
||||||
|
|
|
@ -502,7 +502,7 @@ async def test_models_source_interaction_preserves_default(cached_disk_dist_regi
|
||||||
|
|
||||||
# Find the user model and provider model
|
# Find the user model and provider model
|
||||||
user_model = next((m for m in models.data if m.identifier == "my-custom-alias"), None)
|
user_model = next((m for m in models.data if m.identifier == "my-custom-alias"), None)
|
||||||
provider_model = next((m for m in models.data if m.identifier == "different-model"), None)
|
provider_model = next((m for m in models.data if m.identifier == "test_provider/different-model"), None)
|
||||||
|
|
||||||
assert user_model is not None
|
assert user_model is not None
|
||||||
assert user_model.source == RegistryEntrySource.via_register_api
|
assert user_model.source == RegistryEntrySource.via_register_api
|
||||||
|
@ -558,12 +558,12 @@ async def test_models_source_interaction_cleanup_provider_models(cached_disk_dis
|
||||||
|
|
||||||
identifiers = {m.identifier for m in models.data}
|
identifiers = {m.identifier for m in models.data}
|
||||||
assert "test_provider/user-model" in identifiers # User model preserved
|
assert "test_provider/user-model" in identifiers # User model preserved
|
||||||
assert "provider-model-new" in identifiers # New provider model (uses provider's identifier)
|
assert "test_provider/provider-model-new" in identifiers # New provider model (uses provider's identifier)
|
||||||
assert "provider-model-old" not in identifiers # Old provider model removed
|
assert "test_provider/provider-model-old" not in identifiers # Old provider model removed
|
||||||
|
|
||||||
# Verify sources are correct
|
# Verify sources are correct
|
||||||
user_model = next((m for m in models.data if m.identifier == "test_provider/user-model"), None)
|
user_model = next((m for m in models.data if m.identifier == "test_provider/user-model"), None)
|
||||||
provider_model = next((m for m in models.data if m.identifier == "provider-model-new"), None)
|
provider_model = next((m for m in models.data if m.identifier == "test_provider/provider-model-new"), None)
|
||||||
|
|
||||||
assert user_model.source == RegistryEntrySource.via_register_api
|
assert user_model.source == RegistryEntrySource.via_register_api
|
||||||
assert provider_model.source == RegistryEntrySource.listed_from_provider
|
assert provider_model.source == RegistryEntrySource.listed_from_provider
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue