Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-12-03 09:53:45 +00:00
feat!: standardize base_url for inference
Completes #3732 by removing runtime URL transformations and requiring users to provide full URLs in configuration. All providers now use `base_url` consistently and respect the exact URL provided, without appending paths such as /v1 or /openai/v1 at runtime.

Adds a unit test to enforce URL standardization across remote inference providers (verifies that all of them use a `base_url` field typed `HttpUrl | None`).

BREAKING CHANGE: Users must update configs to include full URL paths (e.g., http://localhost:11434/v1 instead of http://localhost:11434).

Signed-off-by: Charlie Doern <cdoern@redhat.com>
parent 7093978754, commit 7a9c32f737
67 changed files with 282 additions and 227 deletions
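The breaking change is easiest to see in a provider config. The sketch below is an illustrative before/after for an Ollama entry, using the field names and defaults that appear in the diffs that follow; the surrounding run.yaml structure is abbreviated.

```yaml
# Before this change: the stack appended /v1 to the configured URL at runtime.
- provider_id: ollama
  provider_type: remote::ollama
  config:
    url: ${env.OLLAMA_URL:=http://localhost:11434}

# After this change: the full URL, including /v1, must be supplied as base_url.
- provider_id: ollama
  provider_type: remote::ollama
  config:
    base_url: ${env.OLLAMA_URL:=http://localhost:11434/v1}
```

Providers with non-standard prefixes follow the same rule: for example, the Azure and Groq defaults below now carry their /openai/v1 suffix explicitly in the configured URL.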
@@ -24,7 +24,7 @@ https://learn.microsoft.com/en-us/azure/ai-foundry/openai/overview
 | `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `refresh_models` | `bool` | No | False | Whether to refresh models periodically from the provider |
 | `api_key` | `SecretStr \| None` | No | | Authentication credential for the provider |
-| `api_base` | `HttpUrl` | No | | Azure API base for Azure (e.g., https://your-resource-name.openai.azure.com) |
+| `base_url` | `HttpUrl \| None` | No | | Azure API base for Azure (e.g., https://your-resource-name.openai.azure.com/openai/v1) |
 | `api_version` | `str \| None` | No | | Azure API version for Azure (e.g., 2024-12-01-preview) |
 | `api_type` | `str \| None` | No | azure | Azure API type for Azure (e.g., azure) |
@@ -32,7 +32,7 @@ https://learn.microsoft.com/en-us/azure/ai-foundry/openai/overview

 ```yaml
 api_key: ${env.AZURE_API_KEY:=}
-api_base: ${env.AZURE_API_BASE:=}
+base_url: ${env.AZURE_API_BASE:=}
 api_version: ${env.AZURE_API_VERSION:=}
 api_type: ${env.AZURE_API_TYPE:=}
 ```
@@ -17,11 +17,11 @@ Cerebras inference provider for running models on Cerebras Cloud platform.
 | `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `refresh_models` | `bool` | No | False | Whether to refresh models periodically from the provider |
 | `api_key` | `SecretStr \| None` | No | | Authentication credential for the provider |
-| `base_url` | `str` | No | https://api.cerebras.ai | Base URL for the Cerebras API |
+| `base_url` | `HttpUrl \| None` | No | https://api.cerebras.ai/v1 | Base URL for the Cerebras API |

 ## Sample Configuration

 ```yaml
-base_url: https://api.cerebras.ai
+base_url: https://api.cerebras.ai/v1
 api_key: ${env.CEREBRAS_API_KEY:=}
 ```
@@ -17,11 +17,11 @@ Databricks inference provider for running models on Databricks' unified analytic
 | `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `refresh_models` | `bool` | No | False | Whether to refresh models periodically from the provider |
 | `api_token` | `SecretStr \| None` | No | | The Databricks API token |
-| `url` | `str \| None` | No | | The URL for the Databricks model serving endpoint |
+| `base_url` | `HttpUrl \| None` | No | | The URL for the Databricks model serving endpoint (should include /serving-endpoints path) |

 ## Sample Configuration

 ```yaml
-url: ${env.DATABRICKS_HOST:=}
+base_url: ${env.DATABRICKS_HOST:=}
 api_token: ${env.DATABRICKS_TOKEN:=}
 ```
@@ -17,11 +17,11 @@ Fireworks AI inference provider for Llama models and other AI models on the Fire
 | `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `refresh_models` | `bool` | No | False | Whether to refresh models periodically from the provider |
 | `api_key` | `SecretStr \| None` | No | | Authentication credential for the provider |
-| `url` | `str` | No | https://api.fireworks.ai/inference/v1 | The URL for the Fireworks server |
+| `base_url` | `HttpUrl \| None` | No | https://api.fireworks.ai/inference/v1 | The URL for the Fireworks server |

 ## Sample Configuration

 ```yaml
-url: https://api.fireworks.ai/inference/v1
+base_url: https://api.fireworks.ai/inference/v1
 api_key: ${env.FIREWORKS_API_KEY:=}
 ```
@@ -17,11 +17,11 @@ Groq inference provider for ultra-fast inference using Groq's LPU technology.
 | `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `refresh_models` | `bool` | No | False | Whether to refresh models periodically from the provider |
 | `api_key` | `SecretStr \| None` | No | | Authentication credential for the provider |
-| `url` | `str` | No | https://api.groq.com | The URL for the Groq AI server |
+| `base_url` | `HttpUrl \| None` | No | https://api.groq.com/openai/v1 | The URL for the Groq AI server |

 ## Sample Configuration

 ```yaml
-url: https://api.groq.com
+base_url: https://api.groq.com/openai/v1
 api_key: ${env.GROQ_API_KEY:=}
 ```
@@ -17,11 +17,11 @@ Llama OpenAI-compatible provider for using Llama models with OpenAI API format.
 | `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `refresh_models` | `bool` | No | False | Whether to refresh models periodically from the provider |
 | `api_key` | `SecretStr \| None` | No | | Authentication credential for the provider |
-| `openai_compat_api_base` | `str` | No | https://api.llama.com/compat/v1/ | The URL for the Llama API server |
+| `base_url` | `HttpUrl \| None` | No | https://api.llama.com/compat/v1/ | The URL for the Llama API server |

 ## Sample Configuration

 ```yaml
-openai_compat_api_base: https://api.llama.com/compat/v1/
+base_url: https://api.llama.com/compat/v1/
 api_key: ${env.LLAMA_API_KEY}
 ```
@@ -17,15 +17,13 @@ NVIDIA inference provider for accessing NVIDIA NIM models and AI services.
 | `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `refresh_models` | `bool` | No | False | Whether to refresh models periodically from the provider |
 | `api_key` | `SecretStr \| None` | No | | Authentication credential for the provider |
-| `url` | `str` | No | https://integrate.api.nvidia.com | A base url for accessing the NVIDIA NIM |
+| `base_url` | `HttpUrl \| None` | No | https://integrate.api.nvidia.com/v1 | A base url for accessing the NVIDIA NIM |
 | `timeout` | `int` | No | 60 | Timeout for the HTTP requests |
-| `append_api_version` | `bool` | No | True | When set to false, the API version will not be appended to the base_url. By default, it is true. |
 | `rerank_model_to_url` | `dict[str, str]` | No | `{'nv-rerank-qa-mistral-4b:1': 'https://ai.api.nvidia.com/v1/retrieval/nvidia/reranking', 'nvidia/nv-rerankqa-mistral-4b-v3': 'https://ai.api.nvidia.com/v1/retrieval/nvidia/nv-rerankqa-mistral-4b-v3/reranking', 'nvidia/llama-3.2-nv-rerankqa-1b-v2': 'https://ai.api.nvidia.com/v1/retrieval/nvidia/llama-3_2-nv-rerankqa-1b-v2/reranking'}` | Mapping of rerank model identifiers to their API endpoints. |

 ## Sample Configuration

 ```yaml
-url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com}
+base_url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com/v1}
 api_key: ${env.NVIDIA_API_KEY:=}
-append_api_version: ${env.NVIDIA_APPEND_API_VERSION:=True}
 ```
@@ -16,10 +16,10 @@ Ollama inference provider for running local models through the Ollama runtime.
 |-------|------|----------|---------|-------------|
 | `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `refresh_models` | `bool` | No | False | Whether to refresh models periodically from the provider |
-| `url` | `str` | No | http://localhost:11434 | |
+| `base_url` | `HttpUrl \| None` | No | http://localhost:11434/v1 | |

 ## Sample Configuration

 ```yaml
-url: ${env.OLLAMA_URL:=http://localhost:11434}
+base_url: ${env.OLLAMA_URL:=http://localhost:11434/v1}
 ```
@@ -17,7 +17,7 @@ OpenAI inference provider for accessing GPT models and other OpenAI services.
 | `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `refresh_models` | `bool` | No | False | Whether to refresh models periodically from the provider |
 | `api_key` | `SecretStr \| None` | No | | Authentication credential for the provider |
-| `base_url` | `str` | No | https://api.openai.com/v1 | Base URL for OpenAI API |
+| `base_url` | `HttpUrl \| None` | No | https://api.openai.com/v1 | Base URL for OpenAI API |

 ## Sample Configuration
@@ -17,11 +17,11 @@ Passthrough inference provider for connecting to any external inference service
 | `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `refresh_models` | `bool` | No | False | Whether to refresh models periodically from the provider |
 | `api_key` | `SecretStr \| None` | No | | Authentication credential for the provider |
-| `url` | `str` | No | | The URL for the passthrough endpoint |
+| `base_url` | `HttpUrl \| None` | No | | The URL for the passthrough endpoint |

 ## Sample Configuration

 ```yaml
-url: ${env.PASSTHROUGH_URL}
+base_url: ${env.PASSTHROUGH_URL}
 api_key: ${env.PASSTHROUGH_API_KEY}
 ```
@@ -17,11 +17,11 @@ RunPod inference provider for running models on RunPod's cloud GPU platform.
 | `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `refresh_models` | `bool` | No | False | Whether to refresh models periodically from the provider |
 | `api_token` | `SecretStr \| None` | No | | The API token |
-| `url` | `str \| None` | No | | The URL for the Runpod model serving endpoint |
+| `base_url` | `HttpUrl \| None` | No | | The URL for the Runpod model serving endpoint |

 ## Sample Configuration

 ```yaml
-url: ${env.RUNPOD_URL:=}
+base_url: ${env.RUNPOD_URL:=}
 api_token: ${env.RUNPOD_API_TOKEN}
 ```
@@ -17,11 +17,11 @@ SambaNova inference provider for running models on SambaNova's dataflow architec
 | `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `refresh_models` | `bool` | No | False | Whether to refresh models periodically from the provider |
 | `api_key` | `SecretStr \| None` | No | | Authentication credential for the provider |
-| `url` | `str` | No | https://api.sambanova.ai/v1 | The URL for the SambaNova AI server |
+| `base_url` | `HttpUrl \| None` | No | https://api.sambanova.ai/v1 | The URL for the SambaNova AI server |

 ## Sample Configuration

 ```yaml
-url: https://api.sambanova.ai/v1
+base_url: https://api.sambanova.ai/v1
 api_key: ${env.SAMBANOVA_API_KEY:=}
 ```
@@ -16,10 +16,10 @@ Text Generation Inference (TGI) provider for HuggingFace model serving.
 |-------|------|----------|---------|-------------|
 | `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `refresh_models` | `bool` | No | False | Whether to refresh models periodically from the provider |
-| `url` | `str` | No | | The URL for the TGI serving endpoint |
+| `base_url` | `HttpUrl \| None` | No | | The URL for the TGI serving endpoint (should include /v1 path) |

 ## Sample Configuration

 ```yaml
-url: ${env.TGI_URL:=}
+base_url: ${env.TGI_URL:=}
 ```
@@ -17,11 +17,11 @@ Together AI inference provider for open-source models and collaborative AI devel
 | `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `refresh_models` | `bool` | No | False | Whether to refresh models periodically from the provider |
 | `api_key` | `SecretStr \| None` | No | | Authentication credential for the provider |
-| `url` | `str` | No | https://api.together.xyz/v1 | The URL for the Together AI server |
+| `base_url` | `HttpUrl \| None` | No | https://api.together.xyz/v1 | The URL for the Together AI server |

 ## Sample Configuration

 ```yaml
-url: https://api.together.xyz/v1
+base_url: https://api.together.xyz/v1
 api_key: ${env.TOGETHER_API_KEY:=}
 ```
@@ -17,14 +17,14 @@ Remote vLLM inference provider for connecting to vLLM servers.
 | `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `refresh_models` | `bool` | No | False | Whether to refresh models periodically from the provider |
 | `api_token` | `SecretStr \| None` | No | | The API token |
-| `url` | `str \| None` | No | | The URL for the vLLM model serving endpoint |
+| `base_url` | `HttpUrl \| None` | No | | The URL for the vLLM model serving endpoint |
 | `max_tokens` | `int` | No | 4096 | Maximum number of tokens to generate. |
 | `tls_verify` | `bool \| str` | No | True | Whether to verify TLS certificates. Can be a boolean or a path to a CA certificate file. |

 ## Sample Configuration

 ```yaml
-url: ${env.VLLM_URL:=}
+base_url: ${env.VLLM_URL:=}
 max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
 api_token: ${env.VLLM_API_TOKEN:=fake}
 tls_verify: ${env.VLLM_TLS_VERIFY:=true}
@@ -17,14 +17,14 @@ IBM WatsonX inference provider for accessing AI models on IBM's WatsonX platform
 | `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `refresh_models` | `bool` | No | False | Whether to refresh models periodically from the provider |
 | `api_key` | `SecretStr \| None` | No | | Authentication credential for the provider |
-| `url` | `str` | No | https://us-south.ml.cloud.ibm.com | A base url for accessing the watsonx.ai |
+| `base_url` | `HttpUrl \| None` | No | https://us-south.ml.cloud.ibm.com | A base url for accessing the watsonx.ai |
 | `project_id` | `str \| None` | No | | The watsonx.ai project ID |
 | `timeout` | `int` | No | 60 | Timeout for the HTTP requests |

 ## Sample Configuration

 ```yaml
-url: ${env.WATSONX_BASE_URL:=https://us-south.ml.cloud.ibm.com}
+base_url: ${env.WATSONX_BASE_URL:=https://us-south.ml.cloud.ibm.com}
 api_key: ${env.WATSONX_API_KEY:=}
 project_id: ${env.WATSONX_PROJECT_ID:=}
 ```
@@ -287,9 +287,9 @@ start_container() {
     # On macOS/Windows, use host.docker.internal to reach host from container
    # On Linux with --network host, use localhost
    if [[ "$(uname)" == "Darwin" ]] || [[ "$(uname)" == *"MINGW"* ]]; then
-        OLLAMA_URL="${OLLAMA_URL:-http://host.docker.internal:11434}"
+        OLLAMA_URL="${OLLAMA_URL:-http://host.docker.internal:11434/v1}"
    else
-        OLLAMA_URL="${OLLAMA_URL:-http://localhost:11434}"
+        OLLAMA_URL="${OLLAMA_URL:-http://localhost:11434/v1}"
    fi
    DOCKER_ENV_VARS="$DOCKER_ENV_VARS -e OLLAMA_URL=$OLLAMA_URL"
@@ -640,7 +640,7 @@ cmd=( run -d "${PLATFORM_OPTS[@]}" --name llama-stack \
     --network llama-net \
     -p "${PORT}:${PORT}" \
     "${server_env_opts[@]}" \
-    -e OLLAMA_URL="http://ollama-server:${OLLAMA_PORT}" \
+    -e OLLAMA_URL="http://ollama-server:${OLLAMA_PORT}/v1" \
     "${SERVER_IMAGE}" --port "${PORT}")

 log "🦙 Starting Llama Stack..."
@@ -17,32 +17,32 @@ providers:
 - provider_id: ${env.CEREBRAS_API_KEY:+cerebras}
   provider_type: remote::cerebras
   config:
-    base_url: https://api.cerebras.ai
+    base_url: https://api.cerebras.ai/v1
     api_key: ${env.CEREBRAS_API_KEY:=}
 - provider_id: ${env.OLLAMA_URL:+ollama}
   provider_type: remote::ollama
   config:
-    url: ${env.OLLAMA_URL:=http://localhost:11434}
+    base_url: ${env.OLLAMA_URL:=http://localhost:11434/v1}
 - provider_id: ${env.VLLM_URL:+vllm}
   provider_type: remote::vllm
   config:
-    url: ${env.VLLM_URL:=}
+    base_url: ${env.VLLM_URL:=}
     max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
     api_token: ${env.VLLM_API_TOKEN:=fake}
     tls_verify: ${env.VLLM_TLS_VERIFY:=true}
 - provider_id: ${env.TGI_URL:+tgi}
   provider_type: remote::tgi
   config:
-    url: ${env.TGI_URL:=}
+    base_url: ${env.TGI_URL:=}
 - provider_id: fireworks
   provider_type: remote::fireworks
   config:
-    url: https://api.fireworks.ai/inference/v1
+    base_url: https://api.fireworks.ai/inference/v1
     api_key: ${env.FIREWORKS_API_KEY:=}
 - provider_id: together
   provider_type: remote::together
   config:
-    url: https://api.together.xyz/v1
+    base_url: https://api.together.xyz/v1
     api_key: ${env.TOGETHER_API_KEY:=}
 - provider_id: bedrock
   provider_type: remote::bedrock
@@ -52,9 +52,8 @@ providers:
 - provider_id: ${env.NVIDIA_API_KEY:+nvidia}
   provider_type: remote::nvidia
   config:
-    url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com}
+    base_url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com/v1}
     api_key: ${env.NVIDIA_API_KEY:=}
-    append_api_version: ${env.NVIDIA_APPEND_API_VERSION:=True}
 - provider_id: openai
   provider_type: remote::openai
   config:
@@ -76,18 +75,18 @@ providers:
 - provider_id: groq
   provider_type: remote::groq
   config:
-    url: https://api.groq.com
+    base_url: https://api.groq.com/openai/v1
     api_key: ${env.GROQ_API_KEY:=}
 - provider_id: sambanova
   provider_type: remote::sambanova
   config:
-    url: https://api.sambanova.ai/v1
+    base_url: https://api.sambanova.ai/v1
     api_key: ${env.SAMBANOVA_API_KEY:=}
 - provider_id: ${env.AZURE_API_KEY:+azure}
   provider_type: remote::azure
   config:
     api_key: ${env.AZURE_API_KEY:=}
-    api_base: ${env.AZURE_API_BASE:=}
+    base_url: ${env.AZURE_API_BASE:=}
     api_version: ${env.AZURE_API_VERSION:=}
     api_type: ${env.AZURE_API_TYPE:=}
 - provider_id: sentence-transformers
@@ -16,9 +16,8 @@ providers:
 - provider_id: nvidia
   provider_type: remote::nvidia
   config:
-    url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com}
+    base_url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com/v1}
     api_key: ${env.NVIDIA_API_KEY:=}
-    append_api_version: ${env.NVIDIA_APPEND_API_VERSION:=True}
 - provider_id: nvidia
   provider_type: remote::nvidia
   config:
@@ -16,9 +16,8 @@ providers:
 - provider_id: nvidia
   provider_type: remote::nvidia
   config:
-    url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com}
+    base_url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com/v1}
     api_key: ${env.NVIDIA_API_KEY:=}
-    append_api_version: ${env.NVIDIA_APPEND_API_VERSION:=True}
 vector_io:
 - provider_id: faiss
   provider_type: inline::faiss
@@ -27,12 +27,12 @@ providers:
 - provider_id: groq
   provider_type: remote::groq
   config:
-    url: https://api.groq.com
+    base_url: https://api.groq.com/openai/v1
     api_key: ${env.GROQ_API_KEY:=}
 - provider_id: together
   provider_type: remote::together
   config:
-    url: https://api.together.xyz/v1
+    base_url: https://api.together.xyz/v1
     api_key: ${env.TOGETHER_API_KEY:=}
 vector_io:
 - provider_id: sqlite-vec
@@ -11,7 +11,7 @@ providers:
 - provider_id: vllm-inference
   provider_type: remote::vllm
   config:
-    url: ${env.VLLM_URL:=http://localhost:8000/v1}
+    base_url: ${env.VLLM_URL:=}
     max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
     api_token: ${env.VLLM_API_TOKEN:=fake}
     tls_verify: ${env.VLLM_TLS_VERIFY:=true}
@@ -15,7 +15,7 @@ providers:
 - provider_id: watsonx
   provider_type: remote::watsonx
   config:
-    url: ${env.WATSONX_BASE_URL:=https://us-south.ml.cloud.ibm.com}
+    base_url: ${env.WATSONX_BASE_URL:=https://us-south.ml.cloud.ibm.com}
     api_key: ${env.WATSONX_API_KEY:=}
     project_id: ${env.WATSONX_PROJECT_ID:=}
 vector_io:
@@ -4,8 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from urllib.parse import urljoin
-
 from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin

 from .config import AzureConfig
@@ -22,4 +20,4 @@ class AzureInferenceAdapter(OpenAIMixin):

         Returns the Azure API base URL from the configuration.
         """
-        return urljoin(str(self.config.api_base), "/openai/v1")
+        return str(self.config.base_url)
@@ -32,8 +32,9 @@ class AzureProviderDataValidator(BaseModel):

 @json_schema_type
 class AzureConfig(RemoteInferenceProviderConfig):
-    api_base: HttpUrl = Field(
-        description="Azure API base for Azure (e.g., https://your-resource-name.openai.azure.com)",
+    base_url: HttpUrl | None = Field(
+        default=None,
+        description="Azure API base for Azure (e.g., https://your-resource-name.openai.azure.com/openai/v1)",
     )
     api_version: str | None = Field(
         default_factory=lambda: os.getenv("AZURE_API_VERSION"),
@@ -48,14 +49,14 @@ class AzureConfig(RemoteInferenceProviderConfig):
     def sample_run_config(
         cls,
         api_key: str = "${env.AZURE_API_KEY:=}",
-        api_base: str = "${env.AZURE_API_BASE:=}",
+        base_url: str = "${env.AZURE_API_BASE:=}",
         api_version: str = "${env.AZURE_API_VERSION:=}",
         api_type: str = "${env.AZURE_API_TYPE:=}",
         **kwargs,
     ) -> dict[str, Any]:
         return {
             "api_key": api_key,
-            "api_base": api_base,
+            "base_url": base_url,
             "api_version": api_version,
             "api_type": api_type,
         }
@@ -4,8 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from urllib.parse import urljoin

from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from llama_stack_api import (
    OpenAIEmbeddingsRequestWithExtraBody,

@@ -21,7 +19,7 @@ class CerebrasInferenceAdapter(OpenAIMixin):
    provider_data_api_key_field: str = "cerebras_api_key"

    def get_base_url(self) -> str:
        return urljoin(self.config.base_url, "v1")
        return str(self.config.base_url)

    async def openai_embeddings(
        self,

@@ -7,12 +7,12 @@
import os
from typing import Any

from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, HttpUrl

from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack_api import json_schema_type

DEFAULT_BASE_URL = "https://api.cerebras.ai"
DEFAULT_BASE_URL = "https://api.cerebras.ai/v1"


class CerebrasProviderDataValidator(BaseModel):

@@ -24,8 +24,8 @@ class CerebrasProviderDataValidator(BaseModel):

@json_schema_type
class CerebrasImplConfig(RemoteInferenceProviderConfig):
    base_url: str = Field(
    base_url: HttpUrl | None = Field(
        default=os.environ.get("CEREBRAS_BASE_URL", DEFAULT_BASE_URL),
        default=HttpUrl(os.environ.get("CEREBRAS_BASE_URL", DEFAULT_BASE_URL)),
        description="Base URL for the Cerebras API",
    )

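The Cerebras default now carries the `/v1` path itself, and an environment override must do the same since nothing is appended later. A small sketch of the new default logic (illustrative only):

```python
import os
from pydantic import HttpUrl

DEFAULT_BASE_URL = "https://api.cerebras.ai/v1"

# Mirrors the new default: an env override must include the /v1 path itself,
# because the adapter no longer appends it at request time.
base_url = HttpUrl(os.environ.get("CEREBRAS_BASE_URL", DEFAULT_BASE_URL))
print(base_url)
```
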
@@ -6,7 +6,7 @@

from typing import Any

from pydantic import BaseModel, Field, SecretStr
from pydantic import BaseModel, Field, HttpUrl, SecretStr

from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack_api import json_schema_type

@@ -21,9 +21,9 @@ class DatabricksProviderDataValidator(BaseModel):

@json_schema_type
class DatabricksImplConfig(RemoteInferenceProviderConfig):
    url: str | None = Field(
    base_url: HttpUrl | None = Field(
        default=None,
        description="The URL for the Databricks model serving endpoint",
        description="The URL for the Databricks model serving endpoint (should include /serving-endpoints path)",
    )
    auth_credential: SecretStr | None = Field(
        default=None,

@@ -34,11 +34,11 @@ class DatabricksImplConfig(RemoteInferenceProviderConfig):
    @classmethod
    def sample_run_config(
        cls,
        url: str = "${env.DATABRICKS_HOST:=}",
        base_url: str = "${env.DATABRICKS_HOST:=}",
        api_token: str = "${env.DATABRICKS_TOKEN:=}",
        **kwargs: Any,
    ) -> dict[str, Any]:
        return {
            "url": url,
            "base_url": base_url,
            "api_token": api_token,
        }

@@ -29,15 +29,21 @@ class DatabricksInferenceAdapter(OpenAIMixin):
    }

    def get_base_url(self) -> str:
        return f"{self.config.url}/serving-endpoints"
        return str(self.config.base_url)

    async def list_provider_model_ids(self) -> Iterable[str]:
        # Filter out None values from endpoint names
        api_token = self._get_api_key_from_config_or_provider_data()
        # WorkspaceClient expects base host without /serving-endpoints suffix
        base_url_str = str(self.config.base_url)
        if base_url_str.endswith("/serving-endpoints"):
            host = base_url_str[:-18]  # Remove '/serving-endpoints'
        else:
            host = base_url_str
        return [
            endpoint.name  # type: ignore[misc]
            for endpoint in WorkspaceClient(
                host=self.config.url, token=api_token
                host=host, token=api_token
            ).serving_endpoints.list()  # TODO: this is not async
        ]

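The slice above drops the fixed-length `/serving-endpoints` suffix (18 characters). An equivalent, arguably clearer spelling uses `str.removesuffix` (Python 3.9+); a sketch with a hypothetical workspace URL:

```python
# Equivalent to the [:-18] slice used above; values here are placeholders.
base_url_str = "https://dbc-12345.cloud.databricks.com/serving-endpoints"
host = base_url_str.removesuffix("/serving-endpoints")
print(host)  # https://dbc-12345.cloud.databricks.com
```
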
@@ -6,7 +6,7 @@

from typing import Any

from pydantic import Field
from pydantic import Field, HttpUrl

from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack_api import json_schema_type

@@ -14,14 +14,14 @@ from llama_stack_api import json_schema_type

@json_schema_type
class FireworksImplConfig(RemoteInferenceProviderConfig):
    url: str = Field(
    base_url: HttpUrl | None = Field(
        default="https://api.fireworks.ai/inference/v1",
        default=HttpUrl("https://api.fireworks.ai/inference/v1"),
        description="The URL for the Fireworks server",
    )

    @classmethod
    def sample_run_config(cls, api_key: str = "${env.FIREWORKS_API_KEY:=}", **kwargs) -> dict[str, Any]:
        return {
            "url": "https://api.fireworks.ai/inference/v1",
            "base_url": "https://api.fireworks.ai/inference/v1",
            "api_key": api_key,
        }

@@ -24,4 +24,4 @@ class FireworksInferenceAdapter(OpenAIMixin):
    provider_data_api_key_field: str = "fireworks_api_key"

    def get_base_url(self) -> str:
        return "https://api.fireworks.ai/inference/v1"
        return str(self.config.base_url)

@@ -6,7 +6,7 @@

from typing import Any

from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, HttpUrl

from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack_api import json_schema_type

@@ -21,14 +21,14 @@ class GroqProviderDataValidator(BaseModel):

@json_schema_type
class GroqConfig(RemoteInferenceProviderConfig):
    url: str = Field(
    base_url: HttpUrl | None = Field(
        default="https://api.groq.com",
        default=HttpUrl("https://api.groq.com/openai/v1"),
        description="The URL for the Groq AI server",
    )

    @classmethod
    def sample_run_config(cls, api_key: str = "${env.GROQ_API_KEY:=}", **kwargs) -> dict[str, Any]:
        return {
            "url": "https://api.groq.com",
            "base_url": "https://api.groq.com/openai/v1",
            "api_key": api_key,
        }

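A hedged before/after sketch of what this means for a Groq provider entry (placeholder values, not taken from any shipped config):

```python
# Before: the /openai/v1 path was appended at runtime by the adapter.
old_config = {"url": "https://api.groq.com"}

# After: the full path is part of the configured value and used as-is.
new_config = {"base_url": "https://api.groq.com/openai/v1"}
```
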
@@ -15,4 +15,4 @@ class GroqInferenceAdapter(OpenAIMixin):
    provider_data_api_key_field: str = "groq_api_key"

    def get_base_url(self) -> str:
        return f"{self.config.url}/openai/v1"
        return str(self.config.base_url)

@@ -6,7 +6,7 @@

from typing import Any

from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, HttpUrl

from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack_api import json_schema_type

@@ -21,14 +21,14 @@ class LlamaProviderDataValidator(BaseModel):

@json_schema_type
class LlamaCompatConfig(RemoteInferenceProviderConfig):
    openai_compat_api_base: str = Field(
    base_url: HttpUrl | None = Field(
        default="https://api.llama.com/compat/v1/",
        default=HttpUrl("https://api.llama.com/compat/v1/"),
        description="The URL for the Llama API server",
    )

    @classmethod
    def sample_run_config(cls, api_key: str = "${env.LLAMA_API_KEY}", **kwargs) -> dict[str, Any]:
        return {
            "openai_compat_api_base": "https://api.llama.com/compat/v1/",
            "base_url": "https://api.llama.com/compat/v1/",
            "api_key": api_key,
        }

@@ -31,7 +31,7 @@ class LlamaCompatInferenceAdapter(OpenAIMixin):

        :return: The Llama API base URL
        """
        return self.config.openai_compat_api_base
        return str(self.config.base_url)

    async def openai_completion(
        self,

@@ -7,7 +7,7 @@
import os
from typing import Any

from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, HttpUrl

from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack_api import json_schema_type

@@ -44,18 +44,14 @@ class NVIDIAConfig(RemoteInferenceProviderConfig):
    URL of your running NVIDIA NIM and do not need to set the api_key.
    """

    url: str = Field(
    base_url: HttpUrl | None = Field(
        default_factory=lambda: os.getenv("NVIDIA_BASE_URL", "https://integrate.api.nvidia.com"),
        default_factory=lambda: os.getenv("NVIDIA_BASE_URL", "https://integrate.api.nvidia.com/v1"),
        description="A base url for accessing the NVIDIA NIM",
    )
    timeout: int = Field(
        default=60,
        description="Timeout for the HTTP requests",
    )
    append_api_version: bool = Field(
        default_factory=lambda: os.getenv("NVIDIA_APPEND_API_VERSION", "True").lower() != "false",
        description="When set to false, the API version will not be appended to the base_url. By default, it is true.",
    )
    rerank_model_to_url: dict[str, str] = Field(
        default_factory=lambda: {
            "nv-rerank-qa-mistral-4b:1": "https://ai.api.nvidia.com/v1/retrieval/nvidia/reranking",

@@ -68,13 +64,11 @@ class NVIDIAConfig(RemoteInferenceProviderConfig):
    @classmethod
    def sample_run_config(
        cls,
        url: str = "${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com}",
        base_url: HttpUrl | None = "${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com/v1}",
        api_key: str = "${env.NVIDIA_API_KEY:=}",
        append_api_version: bool = "${env.NVIDIA_APPEND_API_VERSION:=True}",
        **kwargs,
    ) -> dict[str, Any]:
        return {
            "url": url,
            "base_url": base_url,
            "api_key": api_key,
            "append_api_version": append_api_version,
        }

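With `append_api_version` gone, the path handling collapses to a plain pass-through. A side-by-side sketch of the old and new behaviour (reconstructed for illustration from the lines above):

```python
# Old behaviour: /v1 was appended unless NVIDIA_APPEND_API_VERSION was "false".
def old_get_base_url(url: str, append_api_version: bool) -> str:
    return f"{url}/v1" if append_api_version else url

# New behaviour: the configured value is returned untouched.
def new_get_base_url(base_url: str) -> str:
    return base_url

print(old_get_base_url("https://integrate.api.nvidia.com", True))  # .../v1
print(new_get_base_url("https://integrate.api.nvidia.com/v1"))     # .../v1
```
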
@@ -44,7 +44,7 @@ class NVIDIAInferenceAdapter(OpenAIMixin):
    }

    async def initialize(self) -> None:
        logger.info(f"Initializing NVIDIAInferenceAdapter({self.config.url})...")
        logger.info(f"Initializing NVIDIAInferenceAdapter({self.config.base_url})...")

        if _is_nvidia_hosted(self.config):
            if not self.config.auth_credential:

@@ -72,7 +72,7 @@ class NVIDIAInferenceAdapter(OpenAIMixin):

        :return: The NVIDIA API base URL
        """
        return f"{self.config.url}/v1" if self.config.append_api_version else self.config.url
        return str(self.config.base_url)

    async def list_provider_model_ids(self) -> Iterable[str]:
        """

@@ -8,4 +8,4 @@ from . import NVIDIAConfig


def _is_nvidia_hosted(config: NVIDIAConfig) -> bool:
    return "integrate.api.nvidia.com" in config.url
    return "integrate.api.nvidia.com" in str(config.base_url)

@@ -6,20 +6,22 @@

from typing import Any

from pydantic import Field, SecretStr
from pydantic import Field, HttpUrl, SecretStr

from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig

DEFAULT_OLLAMA_URL = "http://localhost:11434"
DEFAULT_OLLAMA_URL = "http://localhost:11434/v1"


class OllamaImplConfig(RemoteInferenceProviderConfig):
    auth_credential: SecretStr | None = Field(default=None, exclude=True)

    url: str = DEFAULT_OLLAMA_URL
    base_url: HttpUrl | None = Field(default=HttpUrl(DEFAULT_OLLAMA_URL))

    @classmethod
    def sample_run_config(cls, url: str = "${env.OLLAMA_URL:=http://localhost:11434}", **kwargs) -> dict[str, Any]:
    def sample_run_config(
        cls, base_url: str = "${env.OLLAMA_URL:=http://localhost:11434/v1}", **kwargs
    ) -> dict[str, Any]:
        return {
            "url": url,
            "base_url": base_url,
        }

@@ -55,17 +55,23 @@ class OllamaInferenceAdapter(OpenAIMixin):
        # ollama client attaches itself to the current event loop (sadly?)
        loop = asyncio.get_running_loop()
        if loop not in self._clients:
            self._clients[loop] = AsyncOllamaClient(host=self.config.url)
            # Ollama client expects base URL without /v1 suffix
            base_url_str = str(self.config.base_url)
            if base_url_str.endswith("/v1"):
                host = base_url_str[:-3]
            else:
                host = base_url_str
            self._clients[loop] = AsyncOllamaClient(host=host)
        return self._clients[loop]

    def get_api_key(self):
        return "NO KEY REQUIRED"

    def get_base_url(self):
        return self.config.url.rstrip("/") + "/v1"
        return str(self.config.base_url)

    async def initialize(self) -> None:
        logger.info(f"checking connectivity to Ollama at `{self.config.base_url}`...")
        r = await self.health()
        if r["status"] == HealthStatus.ERROR:
            logger.warning(

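The native Ollama client still wants the bare server address, while the OpenAI-compatible path keeps `/v1`. A sketch of the same suffix handling using `removesuffix` (Python 3.9+, illustrative only):

```python
base_url_str = "http://localhost:11434/v1"
host = base_url_str.removesuffix("/v1")   # native Ollama client gets the bare host
openai_base = base_url_str                # OpenAI-compatible calls keep the /v1 path
print(host, openai_base)
```
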
@@ -6,7 +6,7 @@

from typing import Any

from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, HttpUrl

from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack_api import json_schema_type

@@ -21,8 +21,8 @@ class OpenAIProviderDataValidator(BaseModel):

@json_schema_type
class OpenAIConfig(RemoteInferenceProviderConfig):
    base_url: str = Field(
    base_url: HttpUrl | None = Field(
        default="https://api.openai.com/v1",
        default=HttpUrl("https://api.openai.com/v1"),
        description="Base URL for OpenAI API",
    )

@@ -35,4 +35,4 @@ class OpenAIInferenceAdapter(OpenAIMixin):

        Returns the OpenAI API base URL from the configuration.
        """
        return self.config.base_url
        return str(self.config.base_url)

@@ -6,7 +6,7 @@

from typing import Any

from pydantic import Field
from pydantic import Field, HttpUrl

from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack_api import json_schema_type

@@ -14,16 +14,16 @@ from llama_stack_api import json_schema_type

@json_schema_type
class PassthroughImplConfig(RemoteInferenceProviderConfig):
    url: str = Field(
    base_url: HttpUrl | None = Field(
        default=None,
        description="The URL for the passthrough endpoint",
    )

    @classmethod
    def sample_run_config(
        cls, url: str = "${env.PASSTHROUGH_URL}", api_key: str = "${env.PASSTHROUGH_API_KEY}", **kwargs
        cls, base_url: HttpUrl | None = "${env.PASSTHROUGH_URL}", api_key: str = "${env.PASSTHROUGH_API_KEY}", **kwargs
    ) -> dict[str, Any]:
        return {
            "url": url,
            "base_url": base_url,
            "api_key": api_key,
        }

@@ -82,8 +82,8 @@ class PassthroughInferenceAdapter(NeedsRequestProviderData, Inference):

    def _get_passthrough_url(self) -> str:
        """Get the passthrough URL from config or provider data."""
        if self.config.url is not None:
        if self.config.base_url is not None:
            return self.config.url
            return str(self.config.base_url)

        provider_data = self.get_request_provider_data()
        if provider_data is None:

@@ -6,7 +6,7 @@

from typing import Any

from pydantic import BaseModel, Field, SecretStr
from pydantic import BaseModel, Field, HttpUrl, SecretStr

from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack_api import json_schema_type

@@ -21,7 +21,7 @@ class RunpodProviderDataValidator(BaseModel):

@json_schema_type
class RunpodImplConfig(RemoteInferenceProviderConfig):
    url: str | None = Field(
    base_url: HttpUrl | None = Field(
        default=None,
        description="The URL for the Runpod model serving endpoint",
    )

@@ -34,6 +34,6 @@ class RunpodImplConfig(RemoteInferenceProviderConfig):
    @classmethod
    def sample_run_config(cls, **kwargs: Any) -> dict[str, Any]:
        return {
            "url": "${env.RUNPOD_URL:=}",
            "base_url": "${env.RUNPOD_URL:=}",
            "api_token": "${env.RUNPOD_API_TOKEN}",
        }

@@ -28,7 +28,7 @@ class RunpodInferenceAdapter(OpenAIMixin):

    def get_base_url(self) -> str:
        """Get base URL for OpenAI client."""
        return self.config.url
        return str(self.config.base_url)

    async def openai_chat_completion(
        self,

|
||||||
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field, HttpUrl
|
||||||
|
|
||||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||||
from llama_stack_api import json_schema_type
|
from llama_stack_api import json_schema_type
|
||||||
|
|
@ -21,14 +21,14 @@ class SambaNovaProviderDataValidator(BaseModel):
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class SambaNovaImplConfig(RemoteInferenceProviderConfig):
|
class SambaNovaImplConfig(RemoteInferenceProviderConfig):
|
||||||
url: str = Field(
|
base_url: HttpUrl | None = Field(
|
||||||
default="https://api.sambanova.ai/v1",
|
default=HttpUrl("https://api.sambanova.ai/v1"),
|
||||||
description="The URL for the SambaNova AI server",
|
description="The URL for the SambaNova AI server",
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def sample_run_config(cls, api_key: str = "${env.SAMBANOVA_API_KEY:=}", **kwargs) -> dict[str, Any]:
|
def sample_run_config(cls, api_key: str = "${env.SAMBANOVA_API_KEY:=}", **kwargs) -> dict[str, Any]:
|
||||||
return {
|
return {
|
||||||
"url": "https://api.sambanova.ai/v1",
|
"base_url": "https://api.sambanova.ai/v1",
|
||||||
"api_key": api_key,
|
"api_key": api_key,
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -25,4 +25,4 @@ class SambaNovaInferenceAdapter(OpenAIMixin):
|
||||||
|
|
||||||
:return: The SambaNova base URL
|
:return: The SambaNova base URL
|
||||||
"""
|
"""
|
||||||
return self.config.url
|
return str(self.config.base_url)
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,7 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, SecretStr
|
from pydantic import BaseModel, Field, HttpUrl, SecretStr
|
||||||
|
|
||||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||||
from llama_stack_api import json_schema_type
|
from llama_stack_api import json_schema_type
|
||||||
|
|
@ -15,18 +15,19 @@ from llama_stack_api import json_schema_type
|
||||||
class TGIImplConfig(RemoteInferenceProviderConfig):
|
class TGIImplConfig(RemoteInferenceProviderConfig):
|
||||||
auth_credential: SecretStr | None = Field(default=None, exclude=True)
|
auth_credential: SecretStr | None = Field(default=None, exclude=True)
|
||||||
|
|
||||||
url: str = Field(
|
base_url: HttpUrl | None = Field(
|
||||||
description="The URL for the TGI serving endpoint",
|
default=None,
|
||||||
|
description="The URL for the TGI serving endpoint (should include /v1 path)",
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def sample_run_config(
|
def sample_run_config(
|
||||||
cls,
|
cls,
|
||||||
url: str = "${env.TGI_URL:=}",
|
base_url: str = "${env.TGI_URL:=}",
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
return {
|
return {
|
||||||
"url": url,
|
"base_url": base_url,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@@ -8,7 +8,7 @@
from collections.abc import Iterable

from huggingface_hub import AsyncInferenceClient, HfApi
from pydantic import SecretStr
from pydantic import HttpUrl, SecretStr

from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin

@@ -23,7 +23,7 @@ log = get_logger(name=__name__, category="inference::tgi")


class _HfAdapter(OpenAIMixin):
    url: str
    base_url: HttpUrl
    api_key: SecretStr

    hf_client: AsyncInferenceClient

@@ -36,7 +36,7 @@ class _HfAdapter(OpenAIMixin):
        return "NO KEY REQUIRED"

    def get_base_url(self):
        return self.url
        return self.base_url

    async def list_provider_model_ids(self) -> Iterable[str]:
        return [self.model_id]

@@ -50,14 +50,20 @@ class _HfAdapter(OpenAIMixin):

class TGIAdapter(_HfAdapter):
    async def initialize(self, config: TGIImplConfig) -> None:
        if not config.url:
        if not config.base_url:
            raise ValueError("You must provide a URL in run.yaml (or via the TGI_URL environment variable) to use TGI.")
        log.info(f"Initializing TGI client with url={config.base_url}")
        self.hf_client = AsyncInferenceClient(model=config.url, provider="hf-inference")
        # Extract base URL without /v1 for HF client initialization
        base_url_str = str(config.base_url).rstrip("/")
        if base_url_str.endswith("/v1"):
            base_url_for_client = base_url_str[:-3]
        else:
            base_url_for_client = base_url_str
        self.hf_client = AsyncInferenceClient(model=base_url_for_client, provider="hf-inference")
        endpoint_info = await self.hf_client.get_endpoint_info()
        self.max_tokens = endpoint_info["max_total_tokens"]
        self.model_id = endpoint_info["model_id"]
        self.url = f"{config.url.rstrip('/')}/v1"
        self.base_url = config.base_url
        self.api_key = SecretStr("NO_KEY")

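The same pattern applies to TGI: the `huggingface_hub` client is pointed at the bare server URL, while the OpenAI-compatible client keeps the configured `/v1` path. A small sketch with a hypothetical endpoint:

```python
base_url = "http://localhost:8080/v1"                          # hypothetical TGI endpoint
hf_client_target = base_url.rstrip("/").removesuffix("/v1")    # huggingface_hub client: bare server URL
openai_target = base_url                                       # OpenAI-compatible client: keeps /v1
print(hf_client_target, openai_target)
```
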
@@ -6,7 +6,7 @@

from typing import Any

from pydantic import Field
from pydantic import Field, HttpUrl

from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack_api import json_schema_type

@@ -14,14 +14,14 @@ from llama_stack_api import json_schema_type

@json_schema_type
class TogetherImplConfig(RemoteInferenceProviderConfig):
    url: str = Field(
    base_url: HttpUrl | None = Field(
        default="https://api.together.xyz/v1",
        default=HttpUrl("https://api.together.xyz/v1"),
        description="The URL for the Together AI server",
    )

    @classmethod
    def sample_run_config(cls, **kwargs) -> dict[str, Any]:
        return {
            "url": "https://api.together.xyz/v1",
            "base_url": "https://api.together.xyz/v1",
            "api_key": "${env.TOGETHER_API_KEY:=}",
        }

@@ -9,7 +9,6 @@ from collections.abc import Iterable
from typing import Any, cast

from together import AsyncTogether  # type: ignore[import-untyped]
from together.constants import BASE_URL  # type: ignore[import-untyped]

from llama_stack.core.request_headers import NeedsRequestProviderData
from llama_stack.log import get_logger

@@ -42,7 +41,7 @@ class TogetherInferenceAdapter(OpenAIMixin, NeedsRequestProviderData):
    provider_data_api_key_field: str = "together_api_key"

    def get_base_url(self):
        return BASE_URL
        return str(self.config.base_url)

    def _get_client(self) -> AsyncTogether:
        together_api_key = None

@@ -6,7 +6,7 @@

from pathlib import Path

from pydantic import Field, SecretStr, field_validator
from pydantic import Field, HttpUrl, SecretStr, field_validator

from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack_api import json_schema_type

@@ -14,7 +14,7 @@ from llama_stack_api import json_schema_type

@json_schema_type
class VLLMInferenceAdapterConfig(RemoteInferenceProviderConfig):
    url: str | None = Field(
    base_url: HttpUrl | None = Field(
        default=None,
        description="The URL for the vLLM model serving endpoint",
    )

@@ -48,11 +48,11 @@ class VLLMInferenceAdapterConfig(RemoteInferenceProviderConfig):
    @classmethod
    def sample_run_config(
        cls,
        url: str = "${env.VLLM_URL:=}",
        base_url: str = "${env.VLLM_URL:=}",
        **kwargs,
    ):
        return {
            "url": url,
            "base_url": base_url,
            "max_tokens": "${env.VLLM_MAX_TOKENS:=4096}",
            "api_token": "${env.VLLM_API_TOKEN:=fake}",
            "tls_verify": "${env.VLLM_TLS_VERIFY:=true}",

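Typing these fields as `HttpUrl | None` also means malformed URLs are rejected when the config object is built rather than at request time. A self-contained sketch with a stand-in model (not the real config class):

```python
from pydantic import BaseModel, HttpUrl, ValidationError

class _DemoConfig(BaseModel):  # stand-in for any provider config using the new field type
    base_url: HttpUrl | None = None

print(_DemoConfig(base_url="http://localhost:8000/v1").base_url)  # parsed and validated
try:
    _DemoConfig(base_url="not-a-url")  # malformed values now fail at construction time
except ValidationError as err:
    print(err.error_count(), "validation error")
```
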
@@ -39,12 +39,12 @@ class VLLMInferenceAdapter(OpenAIMixin):

    def get_base_url(self) -> str:
        """Get the base URL from config."""
        if not self.config.url:
        if not self.config.base_url:
            raise ValueError("No base URL configured")
        return self.config.url
        return str(self.config.base_url)

    async def initialize(self) -> None:
        if not self.config.url:
        if not self.config.base_url:
            raise ValueError(
                "You must provide a URL in run.yaml (or via the VLLM_URL environment variable) to use vLLM."
            )

@@ -7,7 +7,7 @@
import os
from typing import Any

from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, HttpUrl

from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack_api import json_schema_type

@@ -23,7 +23,7 @@ class WatsonXProviderDataValidator(BaseModel):

@json_schema_type
class WatsonXConfig(RemoteInferenceProviderConfig):
    url: str = Field(
    base_url: HttpUrl | None = Field(
        default_factory=lambda: os.getenv("WATSONX_BASE_URL", "https://us-south.ml.cloud.ibm.com"),
        description="A base url for accessing the watsonx.ai",
    )

@@ -39,7 +39,7 @@ class WatsonXConfig(RemoteInferenceProviderConfig):
    @classmethod
    def sample_run_config(cls, **kwargs) -> dict[str, Any]:
        return {
            "url": "${env.WATSONX_BASE_URL:=https://us-south.ml.cloud.ibm.com}",
            "base_url": "${env.WATSONX_BASE_URL:=https://us-south.ml.cloud.ibm.com}",
            "api_key": "${env.WATSONX_API_KEY:=}",
            "project_id": "${env.WATSONX_PROJECT_ID:=}",
        }

@@ -255,7 +255,7 @@ class WatsonXInferenceAdapter(LiteLLMOpenAIMixin):
        )

    def get_base_url(self) -> str:
        return self.config.url
        return str(self.config.base_url)

    # Copied from OpenAIMixin
    async def check_model_availability(self, model: str) -> bool:

@@ -316,7 +316,7 @@ class WatsonXInferenceAdapter(LiteLLMOpenAIMixin):
        """
        Retrieves foundation model specifications from the watsonx.ai API.
        """
        url = f"{self.config.url}/ml/v1/foundation_model_specs?version=2023-10-25"
        url = f"{str(self.config.base_url)}/ml/v1/foundation_model_specs?version=2023-10-25"
        headers = {
            # Note that there is no authorization header. Listing models does not require authentication.
            "Content-Type": "application/json",

@@ -50,7 +50,7 @@ SETUP_DEFINITIONS: dict[str, Setup] = {
        name="ollama",
        description="Local Ollama provider with text + safety models",
        env={
            "OLLAMA_URL": "http://0.0.0.0:11434",
            "OLLAMA_URL": "http://0.0.0.0:11434/v1",
            "SAFETY_MODEL": "ollama/llama-guard3:1b",
        },
        defaults={

@@ -64,7 +64,7 @@ SETUP_DEFINITIONS: dict[str, Setup] = {
        name="ollama",
        description="Local Ollama provider with a vision model",
        env={
            "OLLAMA_URL": "http://0.0.0.0:11434",
            "OLLAMA_URL": "http://0.0.0.0:11434/v1",
        },
        defaults={
            "vision_model": "ollama/llama3.2-vision:11b",

@@ -75,7 +75,7 @@ SETUP_DEFINITIONS: dict[str, Setup] = {
        name="ollama-postgres",
        description="Server-mode tests with Postgres-backed persistence",
        env={
            "OLLAMA_URL": "http://0.0.0.0:11434",
            "OLLAMA_URL": "http://0.0.0.0:11434/v1",
            "SAFETY_MODEL": "ollama/llama-guard3:1b",
            "POSTGRES_HOST": "127.0.0.1",
            "POSTGRES_PORT": "5432",

@@ -120,7 +120,7 @@ from llama_stack.providers.remote.inference.watsonx.watsonx import WatsonXInfere
            VLLMInferenceAdapter,
            "llama_stack.providers.remote.inference.vllm.VLLMProviderDataValidator",
            {
                "url": "http://fake",
                "base_url": "http://fake",
            },
        ),
    ],

@@ -153,7 +153,7 @@ def test_litellm_provider_data_used(config_cls, adapter_cls, provider_data_valid
    """Validate data for LiteLLM-based providers. Similar to test_openai_provider_data_used, but without the
    assumption that there is an OpenAI-compatible client object."""

    inference_adapter = adapter_cls(config=config_cls())
    inference_adapter = adapter_cls(config=config_cls(base_url="http://fake"))

    inference_adapter.__provider_spec__ = MagicMock()
    inference_adapter.__provider_spec__.provider_data_validator = provider_data_validator

@@ -40,7 +40,7 @@ from llama_stack_api import (

@pytest.fixture(scope="function")
async def vllm_inference_adapter():
    config = VLLMInferenceAdapterConfig(url="http://mocked.localhost:12345")
    config = VLLMInferenceAdapterConfig(base_url="http://mocked.localhost:12345")
    inference_adapter = VLLMInferenceAdapter(config=config)
    inference_adapter.model_store = AsyncMock()
    await inference_adapter.initialize()

@@ -204,7 +204,7 @@ async def test_vllm_completion_extra_body():
    via extra_body to the underlying OpenAI client through the InferenceRouter.
    """
    # Set up the vLLM adapter
    config = VLLMInferenceAdapterConfig(url="http://mocked.localhost:12345")
    config = VLLMInferenceAdapterConfig(base_url="http://mocked.localhost:12345")
    vllm_adapter = VLLMInferenceAdapter(config=config)
    vllm_adapter.__provider_id__ = "vllm"
    await vllm_adapter.initialize()

@@ -277,7 +277,7 @@ async def test_vllm_chat_completion_extra_body():
    via extra_body to the underlying OpenAI client through the InferenceRouter for chat completion.
    """
    # Set up the vLLM adapter
    config = VLLMInferenceAdapterConfig(url="http://mocked.localhost:12345")
    config = VLLMInferenceAdapterConfig(base_url="http://mocked.localhost:12345")
    vllm_adapter = VLLMInferenceAdapter(config=config)
    vllm_adapter.__provider_id__ = "vllm"
    await vllm_adapter.initialize()

@@ -146,7 +146,7 @@ async def test_hosted_model_not_in_endpoint_mapping():

async def test_self_hosted_ignores_endpoint():
    adapter = create_adapter(
        config=NVIDIAConfig(url="http://localhost:8000", api_key=None),
        config=NVIDIAConfig(base_url="http://localhost:8000", api_key=None),
        rerank_endpoints={"test-model": "https://model.endpoint/rerank"},  # This should be ignored for self-hosted.
    )
    mock_session = MockSession(MockResponse())

@@ -4,8 +4,10 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import get_args, get_origin

import pytest
from pydantic import BaseModel
from pydantic import BaseModel, HttpUrl

from llama_stack.core.distribution import get_provider_registry, providable_apis
from llama_stack.core.utils.dynamic import instantiate_class_type

@@ -41,3 +43,55 @@ class TestProviderConfigurations:

        sample_config = config_type.sample_run_config(__distro_dir__="foobarbaz")
        assert isinstance(sample_config, dict), f"{config_class_name}.sample_run_config() did not return a dict"

    def test_remote_inference_url_standardization(self):
        """Verify all remote inference providers use standardized base_url configuration."""
        provider_registry = get_provider_registry()
        inference_providers = provider_registry.get("inference", {})

        # Filter for remote providers only
        remote_providers = {k: v for k, v in inference_providers.items() if k.startswith("remote::")}

        failures = []
        for provider_type, provider_spec in remote_providers.items():
            try:
                config_class_name = provider_spec.config_class
                config_type = instantiate_class_type(config_class_name)

                # Check that config has base_url field (not url)
                if hasattr(config_type, "model_fields"):
                    fields = config_type.model_fields

                    # Should NOT have 'url' field (old pattern)
                    if "url" in fields:
                        failures.append(
                            f"{provider_type}: Uses deprecated 'url' field instead of 'base_url'. "
                            f"Please rename to 'base_url' for consistency."
                        )

                    # Should have 'base_url' field with HttpUrl | None type
                    if "base_url" in fields:
                        field_info = fields["base_url"]
                        annotation = field_info.annotation

                        # Check if it's HttpUrl or HttpUrl | None
                        # get_origin() returns Union for (X | Y), None for plain types
                        # get_args() returns the types inside Union, e.g. (HttpUrl, NoneType)
                        is_valid = False
                        if get_origin(annotation) is not None:  # It's a Union/Optional
                            if HttpUrl in get_args(annotation):
                                is_valid = True
                        elif annotation == HttpUrl:  # Plain HttpUrl without | None
                            is_valid = True

                        if not is_valid:
                            failures.append(
                                f"{provider_type}: base_url field has incorrect type annotation. "
                                f"Expected 'HttpUrl | None', got '{annotation}'"
                            )

            except Exception as e:
                failures.append(f"{provider_type}: Error checking URL standardization: {str(e)}")

        if failures:
            pytest.fail("URL standardization violations found:\n" + "\n".join(f" - {f}" for f in failures))

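The new test leans on `typing.get_origin`/`get_args` to accept both `HttpUrl` and `HttpUrl | None`. A quick sketch of what those calls return for a PEP 604 union:

```python
from typing import get_args, get_origin
from pydantic import HttpUrl

annotation = HttpUrl | None
print(get_origin(annotation))            # types.UnionType for `X | Y` unions
print(get_args(annotation))              # (HttpUrl, <class 'NoneType'>)
print(HttpUrl in get_args(annotation))   # True -> accepted by the new test
```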