feat!: standardize base_url for inference (#4177)
# What does this PR do?

Completes #3732 by removing runtime URL transformations and requiring users to provide full URLs in configuration. All providers now use `base_url` consistently and respect the exact URL provided, without appending paths such as `/v1` or `/openai/v1` at runtime.

BREAKING CHANGE: users must update their configs to include the full URL path (e.g., `http://localhost:11434/v1` instead of `http://localhost:11434`).

Closes #3732

## Test Plan

Existing tests should continue to pass, since the default URLs have been updated to match the new convention. A new unit test enforces URL standardization across remote inference providers, verifying that each one uses a `base_url` field of type `HttpUrl | None`.

Signed-off-by: Charlie Doern <cdoern@redhat.com>
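As a concrete illustration of the breaking change, a before/after migration for an Ollama provider entry might look like the following sketch. It uses the default values from this commit; real deployments should substitute their own host and port.

```yaml
# Before: the stack appended /v1 to this URL at runtime.
inference:
- provider_id: ollama
  provider_type: remote::ollama
  config:
    url: http://localhost:11434
---
# After: the full path is spelled out and used exactly as written.
inference:
- provider_id: ollama
  provider_type: remote::ollama
  config:
    base_url: http://localhost:11434/v1
```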
This commit is contained in: parent 91f1b352b4, commit d5cd0eea14.
67 changed files with 282 additions and 227 deletions.
@@ -24,7 +24,7 @@ https://learn.microsoft.com/en-us/azure/ai-foundry/openai/overview
 | `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `refresh_models` | `bool` | No | False | Whether to refresh models periodically from the provider |
 | `api_key` | `SecretStr \| None` | No | | Authentication credential for the provider |
-| `api_base` | `HttpUrl` | No | | Azure API base for Azure (e.g., https://your-resource-name.openai.azure.com) |
+| `base_url` | `HttpUrl \| None` | No | | Azure API base for Azure (e.g., https://your-resource-name.openai.azure.com/openai/v1) |
 | `api_version` | `str \| None` | No | | Azure API version for Azure (e.g., 2024-12-01-preview) |
 | `api_type` | `str \| None` | No | azure | Azure API type for Azure (e.g., azure) |

@@ -32,7 +32,7 @@ https://learn.microsoft.com/en-us/azure/ai-foundry/openai/overview

 ```yaml
 api_key: ${env.AZURE_API_KEY:=}
-api_base: ${env.AZURE_API_BASE:=}
+base_url: ${env.AZURE_API_BASE:=}
 api_version: ${env.AZURE_API_VERSION:=}
 api_type: ${env.AZURE_API_TYPE:=}
 ```
@@ -17,11 +17,11 @@ Cerebras inference provider for running models on Cerebras Cloud platform.
 | `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `refresh_models` | `bool` | No | False | Whether to refresh models periodically from the provider |
 | `api_key` | `SecretStr \| None` | No | | Authentication credential for the provider |
-| `base_url` | `str` | No | https://api.cerebras.ai | Base URL for the Cerebras API |
+| `base_url` | `HttpUrl \| None` | No | https://api.cerebras.ai/v1 | Base URL for the Cerebras API |

 ## Sample Configuration

 ```yaml
-base_url: https://api.cerebras.ai
+base_url: https://api.cerebras.ai/v1
 api_key: ${env.CEREBRAS_API_KEY:=}
 ```
@@ -17,11 +17,11 @@ Databricks inference provider for running models on Databricks' unified analytics platform.
 | `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `refresh_models` | `bool` | No | False | Whether to refresh models periodically from the provider |
 | `api_token` | `SecretStr \| None` | No | | The Databricks API token |
-| `url` | `str \| None` | No | | The URL for the Databricks model serving endpoint |
+| `base_url` | `HttpUrl \| None` | No | | The URL for the Databricks model serving endpoint (should include /serving-endpoints path) |

 ## Sample Configuration

 ```yaml
-url: ${env.DATABRICKS_HOST:=}
+base_url: ${env.DATABRICKS_HOST:=}
 api_token: ${env.DATABRICKS_TOKEN:=}
 ```
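To make the note above concrete, here is a minimal sketch of a migrated Databricks entry; the workspace hostname is a hypothetical placeholder.

```yaml
# Sketch only: my-workspace.cloud.databricks.com is a placeholder hostname.
base_url: https://my-workspace.cloud.databricks.com/serving-endpoints
api_token: ${env.DATABRICKS_TOKEN:=}
```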
@@ -17,11 +17,11 @@ Fireworks AI inference provider for Llama models and other AI models on the Fireworks platform.
 | `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `refresh_models` | `bool` | No | False | Whether to refresh models periodically from the provider |
 | `api_key` | `SecretStr \| None` | No | | Authentication credential for the provider |
-| `url` | `str` | No | https://api.fireworks.ai/inference/v1 | The URL for the Fireworks server |
+| `base_url` | `HttpUrl \| None` | No | https://api.fireworks.ai/inference/v1 | The URL for the Fireworks server |

 ## Sample Configuration

 ```yaml
-url: https://api.fireworks.ai/inference/v1
+base_url: https://api.fireworks.ai/inference/v1
 api_key: ${env.FIREWORKS_API_KEY:=}
 ```
@@ -17,11 +17,11 @@ Groq inference provider for ultra-fast inference using Groq's LPU technology.
 | `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `refresh_models` | `bool` | No | False | Whether to refresh models periodically from the provider |
 | `api_key` | `SecretStr \| None` | No | | Authentication credential for the provider |
-| `url` | `str` | No | https://api.groq.com | The URL for the Groq AI server |
+| `base_url` | `HttpUrl \| None` | No | https://api.groq.com/openai/v1 | The URL for the Groq AI server |

 ## Sample Configuration

 ```yaml
-url: https://api.groq.com
+base_url: https://api.groq.com/openai/v1
 api_key: ${env.GROQ_API_KEY:=}
 ```
@@ -17,11 +17,11 @@ Llama OpenAI-compatible provider for using Llama models with OpenAI API format.
 | `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `refresh_models` | `bool` | No | False | Whether to refresh models periodically from the provider |
 | `api_key` | `SecretStr \| None` | No | | Authentication credential for the provider |
-| `openai_compat_api_base` | `str` | No | https://api.llama.com/compat/v1/ | The URL for the Llama API server |
+| `base_url` | `HttpUrl \| None` | No | https://api.llama.com/compat/v1/ | The URL for the Llama API server |

 ## Sample Configuration

 ```yaml
-openai_compat_api_base: https://api.llama.com/compat/v1/
+base_url: https://api.llama.com/compat/v1/
 api_key: ${env.LLAMA_API_KEY}
 ```
@@ -17,15 +17,13 @@ NVIDIA inference provider for accessing NVIDIA NIM models and AI services.
 | `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `refresh_models` | `bool` | No | False | Whether to refresh models periodically from the provider |
 | `api_key` | `SecretStr \| None` | No | | Authentication credential for the provider |
-| `url` | `str` | No | https://integrate.api.nvidia.com | A base url for accessing the NVIDIA NIM |
+| `base_url` | `HttpUrl \| None` | No | https://integrate.api.nvidia.com/v1 | A base url for accessing the NVIDIA NIM |
 | `timeout` | `int` | No | 60 | Timeout for the HTTP requests |
-| `append_api_version` | `bool` | No | True | When set to false, the API version will not be appended to the base_url. By default, it is true. |
 | `rerank_model_to_url` | `dict[str, str]` | No | `{'nv-rerank-qa-mistral-4b:1': 'https://ai.api.nvidia.com/v1/retrieval/nvidia/reranking', 'nvidia/nv-rerankqa-mistral-4b-v3': 'https://ai.api.nvidia.com/v1/retrieval/nvidia/nv-rerankqa-mistral-4b-v3/reranking', 'nvidia/llama-3.2-nv-rerankqa-1b-v2': 'https://ai.api.nvidia.com/v1/retrieval/nvidia/llama-3_2-nv-rerankqa-1b-v2/reranking'}` | Mapping of rerank model identifiers to their API endpoints. |

 ## Sample Configuration

 ```yaml
-url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com}
+base_url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com/v1}
 api_key: ${env.NVIDIA_API_KEY:=}
-append_api_version: ${env.NVIDIA_APPEND_API_VERSION:=True}
 ```
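The NVIDIA migration differs slightly from the other providers: besides renaming `url` to `base_url`, the `append_api_version` flag is removed, so the `/v1` suffix must now be part of the URL itself. A minimal sketch of an updated entry:

```yaml
# Sketch of a migrated NVIDIA config; append_api_version no longer exists.
base_url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com/v1}
api_key: ${env.NVIDIA_API_KEY:=}
```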
@@ -16,10 +16,10 @@ Ollama inference provider for running local models through the Ollama runtime.
 |-------|------|----------|---------|-------------|
 | `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `refresh_models` | `bool` | No | False | Whether to refresh models periodically from the provider |
-| `url` | `str` | No | http://localhost:11434 | |
+| `base_url` | `HttpUrl \| None` | No | http://localhost:11434/v1 | |

 ## Sample Configuration

 ```yaml
-url: ${env.OLLAMA_URL:=http://localhost:11434}
+base_url: ${env.OLLAMA_URL:=http://localhost:11434/v1}
 ```
@@ -17,7 +17,7 @@ OpenAI inference provider for accessing GPT models and other OpenAI services.
 | `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `refresh_models` | `bool` | No | False | Whether to refresh models periodically from the provider |
 | `api_key` | `SecretStr \| None` | No | | Authentication credential for the provider |
-| `base_url` | `str` | No | https://api.openai.com/v1 | Base URL for OpenAI API |
+| `base_url` | `HttpUrl \| None` | No | https://api.openai.com/v1 | Base URL for OpenAI API |

 ## Sample Configuration

@@ -17,11 +17,11 @@ Passthrough inference provider for connecting to any external inference service
 | `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `refresh_models` | `bool` | No | False | Whether to refresh models periodically from the provider |
 | `api_key` | `SecretStr \| None` | No | | Authentication credential for the provider |
-| `url` | `str` | No | | The URL for the passthrough endpoint |
+| `base_url` | `HttpUrl \| None` | No | | The URL for the passthrough endpoint |

 ## Sample Configuration

 ```yaml
-url: ${env.PASSTHROUGH_URL}
+base_url: ${env.PASSTHROUGH_URL}
 api_key: ${env.PASSTHROUGH_API_KEY}
 ```
@@ -17,11 +17,11 @@ RunPod inference provider for running models on RunPod's cloud GPU platform.
 | `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `refresh_models` | `bool` | No | False | Whether to refresh models periodically from the provider |
 | `api_token` | `SecretStr \| None` | No | | The API token |
-| `url` | `str \| None` | No | | The URL for the Runpod model serving endpoint |
+| `base_url` | `HttpUrl \| None` | No | | The URL for the Runpod model serving endpoint |

 ## Sample Configuration

 ```yaml
-url: ${env.RUNPOD_URL:=}
+base_url: ${env.RUNPOD_URL:=}
 api_token: ${env.RUNPOD_API_TOKEN}
 ```
@@ -17,11 +17,11 @@ SambaNova inference provider for running models on SambaNova's dataflow architecture.
 | `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `refresh_models` | `bool` | No | False | Whether to refresh models periodically from the provider |
 | `api_key` | `SecretStr \| None` | No | | Authentication credential for the provider |
-| `url` | `str` | No | https://api.sambanova.ai/v1 | The URL for the SambaNova AI server |
+| `base_url` | `HttpUrl \| None` | No | https://api.sambanova.ai/v1 | The URL for the SambaNova AI server |

 ## Sample Configuration

 ```yaml
-url: https://api.sambanova.ai/v1
+base_url: https://api.sambanova.ai/v1
 api_key: ${env.SAMBANOVA_API_KEY:=}
 ```
@@ -16,10 +16,10 @@ Text Generation Inference (TGI) provider for HuggingFace model serving.
 |-------|------|----------|---------|-------------|
 | `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `refresh_models` | `bool` | No | False | Whether to refresh models periodically from the provider |
-| `url` | `str` | No | | The URL for the TGI serving endpoint |
+| `base_url` | `HttpUrl \| None` | No | | The URL for the TGI serving endpoint (should include /v1 path) |

 ## Sample Configuration

 ```yaml
-url: ${env.TGI_URL:=}
+base_url: ${env.TGI_URL:=}
 ```
@@ -17,11 +17,11 @@ Together AI inference provider for open-source models and collaborative AI development.
 | `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `refresh_models` | `bool` | No | False | Whether to refresh models periodically from the provider |
 | `api_key` | `SecretStr \| None` | No | | Authentication credential for the provider |
-| `url` | `str` | No | https://api.together.xyz/v1 | The URL for the Together AI server |
+| `base_url` | `HttpUrl \| None` | No | https://api.together.xyz/v1 | The URL for the Together AI server |

 ## Sample Configuration

 ```yaml
-url: https://api.together.xyz/v1
+base_url: https://api.together.xyz/v1
 api_key: ${env.TOGETHER_API_KEY:=}
 ```
@@ -17,14 +17,14 @@ Remote vLLM inference provider for connecting to vLLM servers.
 | `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `refresh_models` | `bool` | No | False | Whether to refresh models periodically from the provider |
 | `api_token` | `SecretStr \| None` | No | | The API token |
-| `url` | `str \| None` | No | | The URL for the vLLM model serving endpoint |
+| `base_url` | `HttpUrl \| None` | No | | The URL for the vLLM model serving endpoint |
 | `max_tokens` | `int` | No | 4096 | Maximum number of tokens to generate. |
 | `tls_verify` | `bool \| str` | No | True | Whether to verify TLS certificates. Can be a boolean or a path to a CA certificate file. |

 ## Sample Configuration

 ```yaml
-url: ${env.VLLM_URL:=}
+base_url: ${env.VLLM_URL:=}
 max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
 api_token: ${env.VLLM_API_TOKEN:=fake}
 tls_verify: ${env.VLLM_TLS_VERIFY:=true}
@@ -17,14 +17,14 @@ IBM WatsonX inference provider for accessing AI models on IBM's WatsonX platform.
 | `allowed_models` | `list[str] \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
 | `refresh_models` | `bool` | No | False | Whether to refresh models periodically from the provider |
 | `api_key` | `SecretStr \| None` | No | | Authentication credential for the provider |
-| `url` | `str` | No | https://us-south.ml.cloud.ibm.com | A base url for accessing the watsonx.ai |
+| `base_url` | `HttpUrl \| None` | No | https://us-south.ml.cloud.ibm.com | A base url for accessing the watsonx.ai |
 | `project_id` | `str \| None` | No | | The watsonx.ai project ID |
 | `timeout` | `int` | No | 60 | Timeout for the HTTP requests |

 ## Sample Configuration

 ```yaml
-url: ${env.WATSONX_BASE_URL:=https://us-south.ml.cloud.ibm.com}
+base_url: ${env.WATSONX_BASE_URL:=https://us-south.ml.cloud.ibm.com}
 api_key: ${env.WATSONX_API_KEY:=}
 project_id: ${env.WATSONX_PROJECT_ID:=}
 ```
@@ -287,9 +287,9 @@ start_container() {
     # On macOS/Windows, use host.docker.internal to reach host from container
     # On Linux with --network host, use localhost
     if [[ "$(uname)" == "Darwin" ]] || [[ "$(uname)" == *"MINGW"* ]]; then
-        OLLAMA_URL="${OLLAMA_URL:-http://host.docker.internal:11434}"
+        OLLAMA_URL="${OLLAMA_URL:-http://host.docker.internal:11434/v1}"
     else
-        OLLAMA_URL="${OLLAMA_URL:-http://localhost:11434}"
+        OLLAMA_URL="${OLLAMA_URL:-http://localhost:11434/v1}"
     fi
     DOCKER_ENV_VARS="$DOCKER_ENV_VARS -e OLLAMA_URL=$OLLAMA_URL"

@@ -640,7 +640,7 @@ cmd=( run -d "${PLATFORM_OPTS[@]}" --name llama-stack \
   --network llama-net \
   -p "${PORT}:${PORT}" \
   "${server_env_opts[@]}" \
-  -e OLLAMA_URL="http://ollama-server:${OLLAMA_PORT}" \
+  -e OLLAMA_URL="http://ollama-server:${OLLAMA_PORT}/v1" \
   "${SERVER_IMAGE}" --port "${PORT}")

 log "🦙 Starting Llama Stack..."
@@ -17,32 +17,32 @@ providers:
   - provider_id: ${env.CEREBRAS_API_KEY:+cerebras}
     provider_type: remote::cerebras
     config:
-      base_url: https://api.cerebras.ai
+      base_url: https://api.cerebras.ai/v1
       api_key: ${env.CEREBRAS_API_KEY:=}
   - provider_id: ${env.OLLAMA_URL:+ollama}
     provider_type: remote::ollama
     config:
-      url: ${env.OLLAMA_URL:=http://localhost:11434}
+      base_url: ${env.OLLAMA_URL:=http://localhost:11434/v1}
   - provider_id: ${env.VLLM_URL:+vllm}
     provider_type: remote::vllm
     config:
-      url: ${env.VLLM_URL:=}
+      base_url: ${env.VLLM_URL:=}
       max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
      api_token: ${env.VLLM_API_TOKEN:=fake}
       tls_verify: ${env.VLLM_TLS_VERIFY:=true}
   - provider_id: ${env.TGI_URL:+tgi}
     provider_type: remote::tgi
     config:
-      url: ${env.TGI_URL:=}
+      base_url: ${env.TGI_URL:=}
   - provider_id: fireworks
     provider_type: remote::fireworks
     config:
-      url: https://api.fireworks.ai/inference/v1
+      base_url: https://api.fireworks.ai/inference/v1
       api_key: ${env.FIREWORKS_API_KEY:=}
   - provider_id: together
     provider_type: remote::together
     config:
-      url: https://api.together.xyz/v1
+      base_url: https://api.together.xyz/v1
       api_key: ${env.TOGETHER_API_KEY:=}
   - provider_id: bedrock
     provider_type: remote::bedrock
@@ -52,9 +52,8 @@ providers:
   - provider_id: ${env.NVIDIA_API_KEY:+nvidia}
     provider_type: remote::nvidia
     config:
-      url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com}
+      base_url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com/v1}
       api_key: ${env.NVIDIA_API_KEY:=}
-      append_api_version: ${env.NVIDIA_APPEND_API_VERSION:=True}
   - provider_id: openai
     provider_type: remote::openai
     config:
@@ -76,18 +75,18 @@ providers:
   - provider_id: groq
     provider_type: remote::groq
     config:
-      url: https://api.groq.com
+      base_url: https://api.groq.com/openai/v1
       api_key: ${env.GROQ_API_KEY:=}
   - provider_id: sambanova
     provider_type: remote::sambanova
     config:
-      url: https://api.sambanova.ai/v1
+      base_url: https://api.sambanova.ai/v1
       api_key: ${env.SAMBANOVA_API_KEY:=}
   - provider_id: ${env.AZURE_API_KEY:+azure}
     provider_type: remote::azure
     config:
       api_key: ${env.AZURE_API_KEY:=}
-      api_base: ${env.AZURE_API_BASE:=}
+      base_url: ${env.AZURE_API_BASE:=}
       api_version: ${env.AZURE_API_VERSION:=}
       api_type: ${env.AZURE_API_TYPE:=}
   - provider_id: sentence-transformers
@@ -17,32 +17,32 @@ providers:
   - provider_id: ${env.CEREBRAS_API_KEY:+cerebras}
     provider_type: remote::cerebras
     config:
-      base_url: https://api.cerebras.ai
+      base_url: https://api.cerebras.ai/v1
       api_key: ${env.CEREBRAS_API_KEY:=}
   - provider_id: ${env.OLLAMA_URL:+ollama}
     provider_type: remote::ollama
     config:
-      url: ${env.OLLAMA_URL:=http://localhost:11434}
+      base_url: ${env.OLLAMA_URL:=http://localhost:11434/v1}
   - provider_id: ${env.VLLM_URL:+vllm}
     provider_type: remote::vllm
     config:
-      url: ${env.VLLM_URL:=}
+      base_url: ${env.VLLM_URL:=}
       max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
       api_token: ${env.VLLM_API_TOKEN:=fake}
       tls_verify: ${env.VLLM_TLS_VERIFY:=true}
   - provider_id: ${env.TGI_URL:+tgi}
     provider_type: remote::tgi
     config:
-      url: ${env.TGI_URL:=}
+      base_url: ${env.TGI_URL:=}
   - provider_id: fireworks
     provider_type: remote::fireworks
     config:
-      url: https://api.fireworks.ai/inference/v1
+      base_url: https://api.fireworks.ai/inference/v1
       api_key: ${env.FIREWORKS_API_KEY:=}
   - provider_id: together
     provider_type: remote::together
     config:
-      url: https://api.together.xyz/v1
+      base_url: https://api.together.xyz/v1
       api_key: ${env.TOGETHER_API_KEY:=}
   - provider_id: bedrock
     provider_type: remote::bedrock
@@ -52,9 +52,8 @@ providers:
   - provider_id: ${env.NVIDIA_API_KEY:+nvidia}
     provider_type: remote::nvidia
     config:
-      url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com}
+      base_url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com/v1}
       api_key: ${env.NVIDIA_API_KEY:=}
-      append_api_version: ${env.NVIDIA_APPEND_API_VERSION:=True}
   - provider_id: openai
     provider_type: remote::openai
     config:
@@ -76,18 +75,18 @@ providers:
   - provider_id: groq
     provider_type: remote::groq
     config:
-      url: https://api.groq.com
+      base_url: https://api.groq.com/openai/v1
       api_key: ${env.GROQ_API_KEY:=}
   - provider_id: sambanova
     provider_type: remote::sambanova
     config:
-      url: https://api.sambanova.ai/v1
+      base_url: https://api.sambanova.ai/v1
       api_key: ${env.SAMBANOVA_API_KEY:=}
   - provider_id: ${env.AZURE_API_KEY:+azure}
     provider_type: remote::azure
     config:
       api_key: ${env.AZURE_API_KEY:=}
-      api_base: ${env.AZURE_API_BASE:=}
+      base_url: ${env.AZURE_API_BASE:=}
       api_version: ${env.AZURE_API_VERSION:=}
       api_type: ${env.AZURE_API_TYPE:=}
   - provider_id: sentence-transformers
@@ -16,9 +16,8 @@ providers:
   - provider_id: nvidia
     provider_type: remote::nvidia
     config:
-      url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com}
+      base_url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com/v1}
       api_key: ${env.NVIDIA_API_KEY:=}
-      append_api_version: ${env.NVIDIA_APPEND_API_VERSION:=True}
   - provider_id: nvidia
     provider_type: remote::nvidia
     config:
@@ -16,9 +16,8 @@ providers:
   - provider_id: nvidia
     provider_type: remote::nvidia
     config:
-      url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com}
+      base_url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com/v1}
       api_key: ${env.NVIDIA_API_KEY:=}
-      append_api_version: ${env.NVIDIA_APPEND_API_VERSION:=True}
   vector_io:
   - provider_id: faiss
     provider_type: inline::faiss
@@ -27,12 +27,12 @@ providers:
   - provider_id: groq
     provider_type: remote::groq
     config:
-      url: https://api.groq.com
+      base_url: https://api.groq.com/openai/v1
       api_key: ${env.GROQ_API_KEY:=}
   - provider_id: together
     provider_type: remote::together
     config:
-      url: https://api.together.xyz/v1
+      base_url: https://api.together.xyz/v1
       api_key: ${env.TOGETHER_API_KEY:=}
   vector_io:
   - provider_id: sqlite-vec
@@ -11,7 +11,7 @@ providers:
   - provider_id: vllm-inference
     provider_type: remote::vllm
     config:
-      url: ${env.VLLM_URL:=http://localhost:8000/v1}
+      base_url: ${env.VLLM_URL:=}
       max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
       api_token: ${env.VLLM_API_TOKEN:=fake}
       tls_verify: ${env.VLLM_TLS_VERIFY:=true}
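Note that this hunk also drops the old default of `http://localhost:8000/v1`, so deployments that relied on it must now set `VLLM_URL` explicitly, full path included. A minimal sketch:

```yaml
# Assumes VLLM_URL is exported in the environment, e.g.:
#   export VLLM_URL=http://localhost:8000/v1
base_url: ${env.VLLM_URL:=}
```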
@@ -17,32 +17,32 @@ providers:
   - provider_id: ${env.CEREBRAS_API_KEY:+cerebras}
     provider_type: remote::cerebras
     config:
-      base_url: https://api.cerebras.ai
+      base_url: https://api.cerebras.ai/v1
       api_key: ${env.CEREBRAS_API_KEY:=}
   - provider_id: ${env.OLLAMA_URL:+ollama}
     provider_type: remote::ollama
     config:
-      url: ${env.OLLAMA_URL:=http://localhost:11434}
+      base_url: ${env.OLLAMA_URL:=http://localhost:11434/v1}
   - provider_id: ${env.VLLM_URL:+vllm}
     provider_type: remote::vllm
     config:
-      url: ${env.VLLM_URL:=}
+      base_url: ${env.VLLM_URL:=}
       max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
       api_token: ${env.VLLM_API_TOKEN:=fake}
       tls_verify: ${env.VLLM_TLS_VERIFY:=true}
   - provider_id: ${env.TGI_URL:+tgi}
     provider_type: remote::tgi
     config:
-      url: ${env.TGI_URL:=}
+      base_url: ${env.TGI_URL:=}
   - provider_id: fireworks
     provider_type: remote::fireworks
     config:
-      url: https://api.fireworks.ai/inference/v1
+      base_url: https://api.fireworks.ai/inference/v1
       api_key: ${env.FIREWORKS_API_KEY:=}
   - provider_id: together
     provider_type: remote::together
     config:
-      url: https://api.together.xyz/v1
+      base_url: https://api.together.xyz/v1
       api_key: ${env.TOGETHER_API_KEY:=}
   - provider_id: bedrock
     provider_type: remote::bedrock
@@ -52,9 +52,8 @@ providers:
   - provider_id: ${env.NVIDIA_API_KEY:+nvidia}
     provider_type: remote::nvidia
     config:
-      url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com}
+      base_url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com/v1}
       api_key: ${env.NVIDIA_API_KEY:=}
-      append_api_version: ${env.NVIDIA_APPEND_API_VERSION:=True}
   - provider_id: openai
     provider_type: remote::openai
     config:
@@ -76,18 +75,18 @@ providers:
   - provider_id: groq
     provider_type: remote::groq
     config:
-      url: https://api.groq.com
+      base_url: https://api.groq.com/openai/v1
       api_key: ${env.GROQ_API_KEY:=}
   - provider_id: sambanova
     provider_type: remote::sambanova
     config:
-      url: https://api.sambanova.ai/v1
+      base_url: https://api.sambanova.ai/v1
       api_key: ${env.SAMBANOVA_API_KEY:=}
   - provider_id: ${env.AZURE_API_KEY:+azure}
     provider_type: remote::azure
     config:
       api_key: ${env.AZURE_API_KEY:=}
-      api_base: ${env.AZURE_API_BASE:=}
+      base_url: ${env.AZURE_API_BASE:=}
       api_version: ${env.AZURE_API_VERSION:=}
       api_type: ${env.AZURE_API_TYPE:=}
   - provider_id: sentence-transformers
@@ -17,32 +17,32 @@ providers:
   - provider_id: ${env.CEREBRAS_API_KEY:+cerebras}
     provider_type: remote::cerebras
     config:
-      base_url: https://api.cerebras.ai
+      base_url: https://api.cerebras.ai/v1
       api_key: ${env.CEREBRAS_API_KEY:=}
   - provider_id: ${env.OLLAMA_URL:+ollama}
     provider_type: remote::ollama
     config:
-      url: ${env.OLLAMA_URL:=http://localhost:11434}
+      base_url: ${env.OLLAMA_URL:=http://localhost:11434/v1}
   - provider_id: ${env.VLLM_URL:+vllm}
     provider_type: remote::vllm
     config:
-      url: ${env.VLLM_URL:=}
+      base_url: ${env.VLLM_URL:=}
       max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
       api_token: ${env.VLLM_API_TOKEN:=fake}
       tls_verify: ${env.VLLM_TLS_VERIFY:=true}
   - provider_id: ${env.TGI_URL:+tgi}
     provider_type: remote::tgi
     config:
-      url: ${env.TGI_URL:=}
+      base_url: ${env.TGI_URL:=}
   - provider_id: fireworks
     provider_type: remote::fireworks
     config:
-      url: https://api.fireworks.ai/inference/v1
+      base_url: https://api.fireworks.ai/inference/v1
       api_key: ${env.FIREWORKS_API_KEY:=}
   - provider_id: together
     provider_type: remote::together
     config:
-      url: https://api.together.xyz/v1
+      base_url: https://api.together.xyz/v1
       api_key: ${env.TOGETHER_API_KEY:=}
   - provider_id: bedrock
     provider_type: remote::bedrock
@@ -52,9 +52,8 @@ providers:
   - provider_id: ${env.NVIDIA_API_KEY:+nvidia}
     provider_type: remote::nvidia
     config:
-      url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com}
+      base_url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com/v1}
       api_key: ${env.NVIDIA_API_KEY:=}
-      append_api_version: ${env.NVIDIA_APPEND_API_VERSION:=True}
   - provider_id: openai
     provider_type: remote::openai
     config:
@@ -76,18 +75,18 @@ providers:
   - provider_id: groq
     provider_type: remote::groq
     config:
-      url: https://api.groq.com
+      base_url: https://api.groq.com/openai/v1
       api_key: ${env.GROQ_API_KEY:=}
   - provider_id: sambanova
     provider_type: remote::sambanova
     config:
-      url: https://api.sambanova.ai/v1
+      base_url: https://api.sambanova.ai/v1
       api_key: ${env.SAMBANOVA_API_KEY:=}
   - provider_id: ${env.AZURE_API_KEY:+azure}
     provider_type: remote::azure
     config:
       api_key: ${env.AZURE_API_KEY:=}
-      api_base: ${env.AZURE_API_BASE:=}
+      base_url: ${env.AZURE_API_BASE:=}
       api_version: ${env.AZURE_API_VERSION:=}
       api_type: ${env.AZURE_API_TYPE:=}
   - provider_id: sentence-transformers
@@ -17,32 +17,32 @@ providers:
   - provider_id: ${env.CEREBRAS_API_KEY:+cerebras}
     provider_type: remote::cerebras
     config:
-      base_url: https://api.cerebras.ai
+      base_url: https://api.cerebras.ai/v1
       api_key: ${env.CEREBRAS_API_KEY:=}
   - provider_id: ${env.OLLAMA_URL:+ollama}
     provider_type: remote::ollama
     config:
-      url: ${env.OLLAMA_URL:=http://localhost:11434}
+      base_url: ${env.OLLAMA_URL:=http://localhost:11434/v1}
   - provider_id: ${env.VLLM_URL:+vllm}
     provider_type: remote::vllm
     config:
-      url: ${env.VLLM_URL:=}
+      base_url: ${env.VLLM_URL:=}
       max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
       api_token: ${env.VLLM_API_TOKEN:=fake}
       tls_verify: ${env.VLLM_TLS_VERIFY:=true}
   - provider_id: ${env.TGI_URL:+tgi}
     provider_type: remote::tgi
     config:
-      url: ${env.TGI_URL:=}
+      base_url: ${env.TGI_URL:=}
   - provider_id: fireworks
     provider_type: remote::fireworks
     config:
-      url: https://api.fireworks.ai/inference/v1
+      base_url: https://api.fireworks.ai/inference/v1
       api_key: ${env.FIREWORKS_API_KEY:=}
   - provider_id: together
     provider_type: remote::together
     config:
-      url: https://api.together.xyz/v1
+      base_url: https://api.together.xyz/v1
       api_key: ${env.TOGETHER_API_KEY:=}
   - provider_id: bedrock
     provider_type: remote::bedrock
@@ -52,9 +52,8 @@ providers:
   - provider_id: ${env.NVIDIA_API_KEY:+nvidia}
     provider_type: remote::nvidia
     config:
-      url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com}
+      base_url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com/v1}
       api_key: ${env.NVIDIA_API_KEY:=}
-      append_api_version: ${env.NVIDIA_APPEND_API_VERSION:=True}
   - provider_id: openai
     provider_type: remote::openai
     config:
@@ -76,18 +75,18 @@ providers:
   - provider_id: groq
     provider_type: remote::groq
     config:
-      url: https://api.groq.com
+      base_url: https://api.groq.com/openai/v1
       api_key: ${env.GROQ_API_KEY:=}
   - provider_id: sambanova
     provider_type: remote::sambanova
    config:
-      url: https://api.sambanova.ai/v1
+      base_url: https://api.sambanova.ai/v1
       api_key: ${env.SAMBANOVA_API_KEY:=}
   - provider_id: ${env.AZURE_API_KEY:+azure}
     provider_type: remote::azure
     config:
       api_key: ${env.AZURE_API_KEY:=}
-      api_base: ${env.AZURE_API_BASE:=}
+      base_url: ${env.AZURE_API_BASE:=}
       api_version: ${env.AZURE_API_VERSION:=}
       api_type: ${env.AZURE_API_TYPE:=}
   - provider_id: sentence-transformers
@@ -17,32 +17,32 @@ providers:
   - provider_id: ${env.CEREBRAS_API_KEY:+cerebras}
     provider_type: remote::cerebras
     config:
-      base_url: https://api.cerebras.ai
+      base_url: https://api.cerebras.ai/v1
       api_key: ${env.CEREBRAS_API_KEY:=}
   - provider_id: ${env.OLLAMA_URL:+ollama}
     provider_type: remote::ollama
     config:
-      url: ${env.OLLAMA_URL:=http://localhost:11434}
+      base_url: ${env.OLLAMA_URL:=http://localhost:11434/v1}
   - provider_id: ${env.VLLM_URL:+vllm}
     provider_type: remote::vllm
     config:
-      url: ${env.VLLM_URL:=}
+      base_url: ${env.VLLM_URL:=}
       max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
       api_token: ${env.VLLM_API_TOKEN:=fake}
       tls_verify: ${env.VLLM_TLS_VERIFY:=true}
   - provider_id: ${env.TGI_URL:+tgi}
     provider_type: remote::tgi
     config:
-      url: ${env.TGI_URL:=}
+      base_url: ${env.TGI_URL:=}
   - provider_id: fireworks
     provider_type: remote::fireworks
     config:
-      url: https://api.fireworks.ai/inference/v1
+      base_url: https://api.fireworks.ai/inference/v1
       api_key: ${env.FIREWORKS_API_KEY:=}
   - provider_id: together
     provider_type: remote::together
     config:
-      url: https://api.together.xyz/v1
+      base_url: https://api.together.xyz/v1
       api_key: ${env.TOGETHER_API_KEY:=}
   - provider_id: bedrock
     provider_type: remote::bedrock
@@ -52,9 +52,8 @@ providers:
   - provider_id: ${env.NVIDIA_API_KEY:+nvidia}
     provider_type: remote::nvidia
     config:
-      url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com}
+      base_url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com/v1}
       api_key: ${env.NVIDIA_API_KEY:=}
-      append_api_version: ${env.NVIDIA_APPEND_API_VERSION:=True}
   - provider_id: openai
     provider_type: remote::openai
     config:
@@ -76,18 +75,18 @@ providers:
   - provider_id: groq
     provider_type: remote::groq
     config:
-      url: https://api.groq.com
+      base_url: https://api.groq.com/openai/v1
       api_key: ${env.GROQ_API_KEY:=}
   - provider_id: sambanova
     provider_type: remote::sambanova
     config:
-      url: https://api.sambanova.ai/v1
+      base_url: https://api.sambanova.ai/v1
       api_key: ${env.SAMBANOVA_API_KEY:=}
   - provider_id: ${env.AZURE_API_KEY:+azure}
     provider_type: remote::azure
     config:
       api_key: ${env.AZURE_API_KEY:=}
-      api_base: ${env.AZURE_API_BASE:=}
+      base_url: ${env.AZURE_API_BASE:=}
       api_version: ${env.AZURE_API_VERSION:=}
       api_type: ${env.AZURE_API_TYPE:=}
   - provider_id: sentence-transformers
@@ -15,7 +15,7 @@ providers:
   - provider_id: watsonx
     provider_type: remote::watsonx
     config:
-      url: ${env.WATSONX_BASE_URL:=https://us-south.ml.cloud.ibm.com}
+      base_url: ${env.WATSONX_BASE_URL:=https://us-south.ml.cloud.ibm.com}
       api_key: ${env.WATSONX_API_KEY:=}
       project_id: ${env.WATSONX_PROJECT_ID:=}
   vector_io:
@@ -4,8 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from urllib.parse import urljoin
-
 from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin

 from .config import AzureConfig
@@ -22,4 +20,4 @@ class AzureInferenceAdapter(OpenAIMixin):

         Returns the Azure API base URL from the configuration.
         """
-        return urljoin(str(self.config.api_base), "/openai/v1")
+        return str(self.config.base_url)
@@ -32,8 +32,9 @@ class AzureProviderDataValidator(BaseModel):

 @json_schema_type
 class AzureConfig(RemoteInferenceProviderConfig):
-    api_base: HttpUrl = Field(
-        description="Azure API base for Azure (e.g., https://your-resource-name.openai.azure.com)",
+    base_url: HttpUrl | None = Field(
+        default=None,
+        description="Azure API base for Azure (e.g., https://your-resource-name.openai.azure.com/openai/v1)",
     )
     api_version: str | None = Field(
         default_factory=lambda: os.getenv("AZURE_API_VERSION"),
@ -48,14 +49,14 @@ class AzureConfig(RemoteInferenceProviderConfig):
|
||||||
def sample_run_config(
|
def sample_run_config(
|
||||||
cls,
|
cls,
|
||||||
api_key: str = "${env.AZURE_API_KEY:=}",
|
api_key: str = "${env.AZURE_API_KEY:=}",
|
||||||
api_base: str = "${env.AZURE_API_BASE:=}",
|
base_url: str = "${env.AZURE_API_BASE:=}",
|
||||||
api_version: str = "${env.AZURE_API_VERSION:=}",
|
api_version: str = "${env.AZURE_API_VERSION:=}",
|
||||||
api_type: str = "${env.AZURE_API_TYPE:=}",
|
api_type: str = "${env.AZURE_API_TYPE:=}",
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
return {
|
return {
|
||||||
"api_key": api_key,
|
"api_key": api_key,
|
||||||
"api_base": api_base,
|
"base_url": base_url,
|
||||||
"api_version": api_version,
|
"api_version": api_version,
|
||||||
"api_type": api_type,
|
"api_type": api_type,
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
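A side effect of typing these fields as `HttpUrl | None` is that pydantic rejects malformed URLs at config-load time instead of at request time, and normalizes the value (in pydantic 2 a bare host typically gains a trailing slash). A minimal sketch of that behavior with a stand-in model, not the actual provider config:

```python
from pydantic import BaseModel, HttpUrl, ValidationError

class DemoConfig(BaseModel):  # stand-in for a provider config
    base_url: HttpUrl | None = None

print(str(DemoConfig(base_url="https://api.openai.com/v1").base_url))
# -> https://api.openai.com/v1

try:
    DemoConfig(base_url="not-a-url")
except ValidationError as e:
    print("rejected:", e.error_count(), "error(s)")
```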
@@ -4,8 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from urllib.parse import urljoin
-
 from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
 from llama_stack_api import (
     OpenAIEmbeddingsRequestWithExtraBody,

@@ -21,7 +19,7 @@ class CerebrasInferenceAdapter(OpenAIMixin):
     provider_data_api_key_field: str = "cerebras_api_key"
 
     def get_base_url(self) -> str:
-        return urljoin(self.config.base_url, "v1")
+        return str(self.config.base_url)
 
     async def openai_embeddings(
         self,
@@ -7,12 +7,12 @@
 import os
 from typing import Any
 
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, HttpUrl
 
 from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack_api import json_schema_type
 
-DEFAULT_BASE_URL = "https://api.cerebras.ai"
+DEFAULT_BASE_URL = "https://api.cerebras.ai/v1"
 
 
 class CerebrasProviderDataValidator(BaseModel):

@@ -24,8 +24,8 @@ class CerebrasProviderDataValidator(BaseModel):
 
 @json_schema_type
 class CerebrasImplConfig(RemoteInferenceProviderConfig):
-    base_url: str = Field(
-        default=os.environ.get("CEREBRAS_BASE_URL", DEFAULT_BASE_URL),
+    base_url: HttpUrl | None = Field(
+        default=HttpUrl(os.environ.get("CEREBRAS_BASE_URL", DEFAULT_BASE_URL)),
         description="Base URL for the Cerebras API",
     )
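One subtlety preserved here: `default=HttpUrl(os.environ.get(...))` is evaluated when the module is imported, so `CEREBRAS_BASE_URL` must be set before import, and an invalid value now fails at import time rather than at first request. A small demonstration of the eager-default pattern (stand-in names, assuming a pydantic version where `HttpUrl` is directly instantiable, as this diff itself does):

```python
import os

from pydantic import BaseModel, HttpUrl

# The default is read once, at class-definition time; later changes to
# the env var do not affect instances created afterwards.
class DemoCerebrasStyleConfig(BaseModel):
    base_url: HttpUrl | None = HttpUrl(os.environ.get("DEMO_BASE_URL", "https://api.cerebras.ai/v1"))

os.environ["DEMO_BASE_URL"] = "http://other:8000/v1"  # too late: default already captured
print(str(DemoCerebrasStyleConfig().base_url))         # still the value read at definition
```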
@@ -6,7 +6,7 @@
 
 from typing import Any
 
-from pydantic import BaseModel, Field, SecretStr
+from pydantic import BaseModel, Field, HttpUrl, SecretStr
 
 from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack_api import json_schema_type

@@ -21,9 +21,9 @@ class DatabricksProviderDataValidator(BaseModel):
 
 @json_schema_type
 class DatabricksImplConfig(RemoteInferenceProviderConfig):
-    url: str | None = Field(
+    base_url: HttpUrl | None = Field(
         default=None,
-        description="The URL for the Databricks model serving endpoint",
+        description="The URL for the Databricks model serving endpoint (should include /serving-endpoints path)",
     )
     auth_credential: SecretStr | None = Field(
         default=None,

@@ -34,11 +34,11 @@ class DatabricksImplConfig(RemoteInferenceProviderConfig):
     @classmethod
     def sample_run_config(
         cls,
-        url: str = "${env.DATABRICKS_HOST:=}",
+        base_url: str = "${env.DATABRICKS_HOST:=}",
         api_token: str = "${env.DATABRICKS_TOKEN:=}",
         **kwargs: Any,
     ) -> dict[str, Any]:
         return {
-            "url": url,
+            "base_url": base_url,
             "api_token": api_token,
         }
@@ -29,15 +29,21 @@ class DatabricksInferenceAdapter(OpenAIMixin):
     }
 
     def get_base_url(self) -> str:
-        return f"{self.config.url}/serving-endpoints"
+        return str(self.config.base_url)
 
     async def list_provider_model_ids(self) -> Iterable[str]:
         # Filter out None values from endpoint names
         api_token = self._get_api_key_from_config_or_provider_data()
+        # WorkspaceClient expects base host without /serving-endpoints suffix
+        base_url_str = str(self.config.base_url)
+        if base_url_str.endswith("/serving-endpoints"):
+            host = base_url_str[:-18]  # Remove '/serving-endpoints'
+        else:
+            host = base_url_str
         return [
             endpoint.name  # type: ignore[misc]
             for endpoint in WorkspaceClient(
-                host=self.config.url, token=api_token
+                host=host, token=api_token
             ).serving_endpoints.list()  # TODO: this is not async
         ]
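The magic number in `base_url_str[:-18]` is easy to get wrong if the suffix ever changes; `str.removesuffix` (Python 3.9+) expresses the same intent without counting characters. A hedged alternative sketch, not the code this PR ships:

```python
# Equivalent suffix handling with str.removesuffix: it returns the string
# unchanged when the suffix is absent, so no branch is needed.
def host_for_workspace_client(base_url: str) -> str:
    return base_url.rstrip("/").removesuffix("/serving-endpoints")

assert host_for_workspace_client("https://dbc-x.cloud.databricks.com/serving-endpoints") \
    == "https://dbc-x.cloud.databricks.com"
assert host_for_workspace_client("https://dbc-x.cloud.databricks.com") \
    == "https://dbc-x.cloud.databricks.com"
```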
@@ -6,7 +6,7 @@
 
 from typing import Any
 
-from pydantic import Field
+from pydantic import Field, HttpUrl
 
 from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack_api import json_schema_type

@@ -14,14 +14,14 @@ from llama_stack_api import json_schema_type
 
 @json_schema_type
 class FireworksImplConfig(RemoteInferenceProviderConfig):
-    url: str = Field(
-        default="https://api.fireworks.ai/inference/v1",
+    base_url: HttpUrl | None = Field(
+        default=HttpUrl("https://api.fireworks.ai/inference/v1"),
         description="The URL for the Fireworks server",
     )
 
     @classmethod
     def sample_run_config(cls, api_key: str = "${env.FIREWORKS_API_KEY:=}", **kwargs) -> dict[str, Any]:
         return {
-            "url": "https://api.fireworks.ai/inference/v1",
+            "base_url": "https://api.fireworks.ai/inference/v1",
             "api_key": api_key,
         }
@@ -24,4 +24,4 @@ class FireworksInferenceAdapter(OpenAIMixin):
     provider_data_api_key_field: str = "fireworks_api_key"
 
     def get_base_url(self) -> str:
-        return "https://api.fireworks.ai/inference/v1"
+        return str(self.config.base_url)
@@ -6,7 +6,7 @@
 
 from typing import Any
 
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, HttpUrl
 
 from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack_api import json_schema_type

@@ -21,14 +21,14 @@ class GroqProviderDataValidator(BaseModel):
 
 @json_schema_type
 class GroqConfig(RemoteInferenceProviderConfig):
-    url: str = Field(
-        default="https://api.groq.com",
+    base_url: HttpUrl | None = Field(
+        default=HttpUrl("https://api.groq.com/openai/v1"),
         description="The URL for the Groq AI server",
     )
 
     @classmethod
     def sample_run_config(cls, api_key: str = "${env.GROQ_API_KEY:=}", **kwargs) -> dict[str, Any]:
         return {
-            "url": "https://api.groq.com",
+            "base_url": "https://api.groq.com/openai/v1",
             "api_key": api_key,
         }
@@ -15,4 +15,4 @@ class GroqInferenceAdapter(OpenAIMixin):
     provider_data_api_key_field: str = "groq_api_key"
 
     def get_base_url(self) -> str:
-        return f"{self.config.url}/openai/v1"
+        return str(self.config.base_url)
@@ -6,7 +6,7 @@
 
 from typing import Any
 
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, HttpUrl
 
 from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack_api import json_schema_type

@@ -21,14 +21,14 @@ class LlamaProviderDataValidator(BaseModel):
 
 @json_schema_type
 class LlamaCompatConfig(RemoteInferenceProviderConfig):
-    openai_compat_api_base: str = Field(
-        default="https://api.llama.com/compat/v1/",
+    base_url: HttpUrl | None = Field(
+        default=HttpUrl("https://api.llama.com/compat/v1/"),
         description="The URL for the Llama API server",
     )
 
     @classmethod
     def sample_run_config(cls, api_key: str = "${env.LLAMA_API_KEY}", **kwargs) -> dict[str, Any]:
         return {
-            "openai_compat_api_base": "https://api.llama.com/compat/v1/",
+            "base_url": "https://api.llama.com/compat/v1/",
             "api_key": api_key,
         }
@@ -31,7 +31,7 @@ class LlamaCompatInferenceAdapter(OpenAIMixin):
 
         :return: The Llama API base URL
         """
-        return self.config.openai_compat_api_base
+        return str(self.config.base_url)
 
     async def openai_completion(
         self,
@@ -7,7 +7,7 @@
 import os
 from typing import Any
 
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, HttpUrl
 
 from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack_api import json_schema_type

@@ -44,18 +44,14 @@ class NVIDIAConfig(RemoteInferenceProviderConfig):
     URL of your running NVIDIA NIM and do not need to set the api_key.
     """
 
-    url: str = Field(
-        default_factory=lambda: os.getenv("NVIDIA_BASE_URL", "https://integrate.api.nvidia.com"),
+    base_url: HttpUrl | None = Field(
+        default_factory=lambda: os.getenv("NVIDIA_BASE_URL", "https://integrate.api.nvidia.com/v1"),
         description="A base url for accessing the NVIDIA NIM",
     )
     timeout: int = Field(
         default=60,
         description="Timeout for the HTTP requests",
     )
-    append_api_version: bool = Field(
-        default_factory=lambda: os.getenv("NVIDIA_APPEND_API_VERSION", "True").lower() != "false",
-        description="When set to false, the API version will not be appended to the base_url. By default, it is true.",
-    )
     rerank_model_to_url: dict[str, str] = Field(
         default_factory=lambda: {
             "nv-rerank-qa-mistral-4b:1": "https://ai.api.nvidia.com/v1/retrieval/nvidia/reranking",

@@ -68,13 +64,11 @@ class NVIDIAConfig(RemoteInferenceProviderConfig):
     @classmethod
     def sample_run_config(
         cls,
-        url: str = "${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com}",
+        base_url: HttpUrl | None = "${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com/v1}",
         api_key: str = "${env.NVIDIA_API_KEY:=}",
-        append_api_version: bool = "${env.NVIDIA_APPEND_API_VERSION:=True}",
         **kwargs,
     ) -> dict[str, Any]:
         return {
-            "url": url,
+            "base_url": base_url,
             "api_key": api_key,
-            "append_api_version": append_api_version,
         }
@@ -44,7 +44,7 @@ class NVIDIAInferenceAdapter(OpenAIMixin):
     }
 
     async def initialize(self) -> None:
-        logger.info(f"Initializing NVIDIAInferenceAdapter({self.config.url})...")
+        logger.info(f"Initializing NVIDIAInferenceAdapter({self.config.base_url})...")
 
         if _is_nvidia_hosted(self.config):
             if not self.config.auth_credential:

@@ -72,7 +72,7 @@ class NVIDIAInferenceAdapter(OpenAIMixin):
 
         :return: The NVIDIA API base URL
         """
-        return f"{self.config.url}/v1" if self.config.append_api_version else self.config.url
+        return str(self.config.base_url)
 
     async def list_provider_model_ids(self) -> Iterable[str]:
         """
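With `append_api_version` gone, the `/v1` decision lives in configuration rather than in code. A before/after sketch of the effective URL computation, simplified and with illustrative values:

```python
# Before: the adapter decided at runtime whether to append /v1.
def old_effective_url(url: str, append_api_version: bool) -> str:
    return f"{url}/v1" if append_api_version else url

# After: the configured base_url already is the full endpoint.
def new_effective_url(base_url: str) -> str:
    return base_url

assert old_effective_url("https://integrate.api.nvidia.com", True) == \
    new_effective_url("https://integrate.api.nvidia.com/v1")
```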
@@ -8,4 +8,4 @@ from . import NVIDIAConfig
 
 
 def _is_nvidia_hosted(config: NVIDIAConfig) -> bool:
-    return "integrate.api.nvidia.com" in config.url
+    return "integrate.api.nvidia.com" in str(config.base_url)
@@ -6,20 +6,22 @@
 
 from typing import Any
 
-from pydantic import Field, SecretStr
+from pydantic import Field, HttpUrl, SecretStr
 
 from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 
-DEFAULT_OLLAMA_URL = "http://localhost:11434"
+DEFAULT_OLLAMA_URL = "http://localhost:11434/v1"
 
 
 class OllamaImplConfig(RemoteInferenceProviderConfig):
     auth_credential: SecretStr | None = Field(default=None, exclude=True)
 
-    url: str = DEFAULT_OLLAMA_URL
+    base_url: HttpUrl | None = Field(default=HttpUrl(DEFAULT_OLLAMA_URL))
 
     @classmethod
-    def sample_run_config(cls, url: str = "${env.OLLAMA_URL:=http://localhost:11434}", **kwargs) -> dict[str, Any]:
+    def sample_run_config(
+        cls, base_url: str = "${env.OLLAMA_URL:=http://localhost:11434/v1}", **kwargs
+    ) -> dict[str, Any]:
         return {
-            "url": url,
+            "base_url": base_url,
         }
@@ -55,17 +55,23 @@ class OllamaInferenceAdapter(OpenAIMixin):
         # ollama client attaches itself to the current event loop (sadly?)
         loop = asyncio.get_running_loop()
         if loop not in self._clients:
-            self._clients[loop] = AsyncOllamaClient(host=self.config.url)
+            # Ollama client expects base URL without /v1 suffix
+            base_url_str = str(self.config.base_url)
+            if base_url_str.endswith("/v1"):
+                host = base_url_str[:-3]
+            else:
+                host = base_url_str
+            self._clients[loop] = AsyncOllamaClient(host=host)
         return self._clients[loop]
 
     def get_api_key(self):
         return "NO KEY REQUIRED"
 
     def get_base_url(self):
-        return self.config.url.rstrip("/") + "/v1"
+        return str(self.config.base_url)
 
     async def initialize(self) -> None:
-        logger.info(f"checking connectivity to Ollama at `{self.config.url}`...")
+        logger.info(f"checking connectivity to Ollama at `{self.config.base_url}`...")
         r = await self.health()
         if r["status"] == HealthStatus.ERROR:
             logger.warning(
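Ollama is the one provider that talks to the same server through two clients: the OpenAI-compatible client wants the `/v1` URL, while the native client wants the bare host. A small sketch of how both endpoints now derive from the single configured value (helper names are illustrative):

```python
# One configured URL, two client views of it (names are illustrative).
CONFIGURED = "http://localhost:11434/v1"

def openai_compatible_base_url(base_url: str) -> str:
    return base_url  # used as-is by the OpenAI-compatible client

def native_ollama_host(base_url: str) -> str:
    return base_url.removesuffix("/v1")  # native client wants the bare host

print(openai_compatible_base_url(CONFIGURED))  # http://localhost:11434/v1
print(native_ollama_host(CONFIGURED))          # http://localhost:11434
```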
@@ -6,7 +6,7 @@
 
 from typing import Any
 
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, HttpUrl
 
 from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack_api import json_schema_type

@@ -21,8 +21,8 @@ class OpenAIProviderDataValidator(BaseModel):
 
 @json_schema_type
 class OpenAIConfig(RemoteInferenceProviderConfig):
-    base_url: str = Field(
-        default="https://api.openai.com/v1",
+    base_url: HttpUrl | None = Field(
+        default=HttpUrl("https://api.openai.com/v1"),
         description="Base URL for OpenAI API",
     )

@@ -35,4 +35,4 @@ class OpenAIInferenceAdapter(OpenAIMixin):
 
         Returns the OpenAI API base URL from the configuration.
         """
-        return self.config.base_url
+        return str(self.config.base_url)
@@ -6,7 +6,7 @@
 
 from typing import Any
 
-from pydantic import Field
+from pydantic import Field, HttpUrl
 
 from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack_api import json_schema_type

@@ -14,16 +14,16 @@ from llama_stack_api import json_schema_type
 
 @json_schema_type
 class PassthroughImplConfig(RemoteInferenceProviderConfig):
-    url: str = Field(
+    base_url: HttpUrl | None = Field(
         default=None,
         description="The URL for the passthrough endpoint",
     )
 
     @classmethod
     def sample_run_config(
-        cls, url: str = "${env.PASSTHROUGH_URL}", api_key: str = "${env.PASSTHROUGH_API_KEY}", **kwargs
+        cls, base_url: HttpUrl | None = "${env.PASSTHROUGH_URL}", api_key: str = "${env.PASSTHROUGH_API_KEY}", **kwargs
     ) -> dict[str, Any]:
         return {
-            "url": url,
+            "base_url": base_url,
             "api_key": api_key,
         }

@@ -82,8 +82,8 @@ class PassthroughInferenceAdapter(NeedsRequestProviderData, Inference):
 
     def _get_passthrough_url(self) -> str:
         """Get the passthrough URL from config or provider data."""
-        if self.config.url is not None:
-            return self.config.url
+        if self.config.base_url is not None:
+            return str(self.config.base_url)
 
         provider_data = self.get_request_provider_data()
         if provider_data is None:
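The passthrough adapter keeps its fallback order: a configured `base_url` wins, and per-request provider data is consulted only when none is configured. A compact sketch of that precedence; the function and the provider-data key are illustrative, not the adapter's actual names:

```python
# Precedence sketch: static config first, then per-request provider data.
def resolve_url(config_base_url: str | None, provider_data: dict | None) -> str:
    if config_base_url is not None:
        return config_base_url
    if provider_data and "passthrough_url" in provider_data:  # key is illustrative
        return provider_data["passthrough_url"]
    raise ValueError("no passthrough URL configured or provided")
```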
@@ -6,7 +6,7 @@
 
 from typing import Any
 
-from pydantic import BaseModel, Field, SecretStr
+from pydantic import BaseModel, Field, HttpUrl, SecretStr
 
 from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack_api import json_schema_type

@@ -21,7 +21,7 @@ class RunpodProviderDataValidator(BaseModel):
 
 @json_schema_type
 class RunpodImplConfig(RemoteInferenceProviderConfig):
-    url: str | None = Field(
+    base_url: HttpUrl | None = Field(
         default=None,
         description="The URL for the Runpod model serving endpoint",
     )

@@ -34,6 +34,6 @@ class RunpodImplConfig(RemoteInferenceProviderConfig):
     @classmethod
     def sample_run_config(cls, **kwargs: Any) -> dict[str, Any]:
         return {
-            "url": "${env.RUNPOD_URL:=}",
+            "base_url": "${env.RUNPOD_URL:=}",
             "api_token": "${env.RUNPOD_API_TOKEN}",
         }
@@ -28,7 +28,7 @@ class RunpodInferenceAdapter(OpenAIMixin):
 
     def get_base_url(self) -> str:
         """Get base URL for OpenAI client."""
-        return self.config.url
+        return str(self.config.base_url)
 
     async def openai_chat_completion(
         self,
@@ -6,7 +6,7 @@
 
 from typing import Any
 
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, HttpUrl
 
 from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack_api import json_schema_type

@@ -21,14 +21,14 @@ class SambaNovaProviderDataValidator(BaseModel):
 
 @json_schema_type
 class SambaNovaImplConfig(RemoteInferenceProviderConfig):
-    url: str = Field(
-        default="https://api.sambanova.ai/v1",
+    base_url: HttpUrl | None = Field(
+        default=HttpUrl("https://api.sambanova.ai/v1"),
         description="The URL for the SambaNova AI server",
     )
 
     @classmethod
     def sample_run_config(cls, api_key: str = "${env.SAMBANOVA_API_KEY:=}", **kwargs) -> dict[str, Any]:
         return {
-            "url": "https://api.sambanova.ai/v1",
+            "base_url": "https://api.sambanova.ai/v1",
             "api_key": api_key,
         }
@@ -25,4 +25,4 @@ class SambaNovaInferenceAdapter(OpenAIMixin):
 
         :return: The SambaNova base URL
         """
-        return self.config.url
+        return str(self.config.base_url)
@@ -5,7 +5,7 @@
 # the root directory of this source tree.
 
 
-from pydantic import BaseModel, Field, SecretStr
+from pydantic import BaseModel, Field, HttpUrl, SecretStr
 
 from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack_api import json_schema_type

@@ -15,18 +15,19 @@ from llama_stack_api import json_schema_type
 class TGIImplConfig(RemoteInferenceProviderConfig):
     auth_credential: SecretStr | None = Field(default=None, exclude=True)
 
-    url: str = Field(
-        description="The URL for the TGI serving endpoint",
+    base_url: HttpUrl | None = Field(
+        default=None,
+        description="The URL for the TGI serving endpoint (should include /v1 path)",
     )
 
     @classmethod
     def sample_run_config(
         cls,
-        url: str = "${env.TGI_URL:=}",
+        base_url: str = "${env.TGI_URL:=}",
         **kwargs,
     ):
         return {
-            "url": url,
+            "base_url": base_url,
         }
@@ -8,7 +8,7 @@
 from collections.abc import Iterable
 
 from huggingface_hub import AsyncInferenceClient, HfApi
-from pydantic import SecretStr
+from pydantic import HttpUrl, SecretStr
 
 from llama_stack.log import get_logger
 from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin

@@ -23,7 +23,7 @@ log = get_logger(name=__name__, category="inference::tgi")
 
 
 class _HfAdapter(OpenAIMixin):
-    url: str
+    base_url: HttpUrl
     api_key: SecretStr
 
     hf_client: AsyncInferenceClient

@@ -36,7 +36,7 @@ class _HfAdapter(OpenAIMixin):
         return "NO KEY REQUIRED"
 
     def get_base_url(self):
-        return self.url
+        return self.base_url
 
     async def list_provider_model_ids(self) -> Iterable[str]:
         return [self.model_id]

@@ -50,14 +50,20 @@ class _HfAdapter(OpenAIMixin):
 
 class TGIAdapter(_HfAdapter):
     async def initialize(self, config: TGIImplConfig) -> None:
-        if not config.url:
+        if not config.base_url:
             raise ValueError("You must provide a URL in run.yaml (or via the TGI_URL environment variable) to use TGI.")
-        log.info(f"Initializing TGI client with url={config.url}")
-        self.hf_client = AsyncInferenceClient(model=config.url, provider="hf-inference")
+        log.info(f"Initializing TGI client with url={config.base_url}")
+        # Extract base URL without /v1 for HF client initialization
+        base_url_str = str(config.base_url).rstrip("/")
+        if base_url_str.endswith("/v1"):
+            base_url_for_client = base_url_str[:-3]
+        else:
+            base_url_for_client = base_url_str
+        self.hf_client = AsyncInferenceClient(model=base_url_for_client, provider="hf-inference")
         endpoint_info = await self.hf_client.get_endpoint_info()
         self.max_tokens = endpoint_info["max_total_tokens"]
         self.model_id = endpoint_info["model_id"]
-        self.url = f"{config.url.rstrip('/')}/v1"
+        self.base_url = config.base_url
         self.api_key = SecretStr("NO_KEY")
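Note the `rstrip("/")` before the suffix check: a URL may come back with a trailing slash, in which case a bare `endswith("/v1")` would miss. A tiny illustration of the normalization (plain string handling, not TGI-specific):

```python
def strip_v1(url: str) -> str:
    url = url.rstrip("/")           # tolerate "http://host:8080/v1/"
    return url.removesuffix("/v1")  # drop the OpenAI-compat path if present

for u in ("http://tgi:8080/v1", "http://tgi:8080/v1/", "http://tgi:8080"):
    print(u, "->", strip_v1(u))
```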
@@ -6,7 +6,7 @@
 
 from typing import Any
 
-from pydantic import Field
+from pydantic import Field, HttpUrl
 
 from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack_api import json_schema_type

@@ -14,14 +14,14 @@ from llama_stack_api import json_schema_type
 
 @json_schema_type
 class TogetherImplConfig(RemoteInferenceProviderConfig):
-    url: str = Field(
-        default="https://api.together.xyz/v1",
+    base_url: HttpUrl | None = Field(
+        default=HttpUrl("https://api.together.xyz/v1"),
         description="The URL for the Together AI server",
     )
 
     @classmethod
     def sample_run_config(cls, **kwargs) -> dict[str, Any]:
         return {
-            "url": "https://api.together.xyz/v1",
+            "base_url": "https://api.together.xyz/v1",
             "api_key": "${env.TOGETHER_API_KEY:=}",
         }
@@ -9,7 +9,6 @@ from collections.abc import Iterable
 from typing import Any, cast
 
 from together import AsyncTogether  # type: ignore[import-untyped]
-from together.constants import BASE_URL  # type: ignore[import-untyped]
 
 from llama_stack.core.request_headers import NeedsRequestProviderData
 from llama_stack.log import get_logger

@@ -42,7 +41,7 @@ class TogetherInferenceAdapter(OpenAIMixin, NeedsRequestProviderData):
     provider_data_api_key_field: str = "together_api_key"
 
     def get_base_url(self):
-        return BASE_URL
+        return str(self.config.base_url)
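Dropping `together.constants.BASE_URL` means the endpoint is no longer pinned by the SDK; pointing the provider at a proxy or mock server is now just a config change. A hedged sketch of the difference (values and function names are illustrative):

```python
# Before (sketch): endpoint fixed at import time by an SDK constant.
HARDCODED = "https://api.together.xyz/v1"  # illustrative value

def old_base_url() -> str:
    return HARDCODED  # config could not override this

# After (sketch): whatever URL the config carries is what gets used.
def new_base_url(config_base_url: str) -> str:
    return config_base_url

print(new_base_url("http://localhost:9999/v1"))  # e.g. a local mock server
```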
@@ -6,7 +6,7 @@
 
 from pathlib import Path
 
-from pydantic import Field, SecretStr, field_validator
+from pydantic import Field, HttpUrl, SecretStr, field_validator
 
 from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack_api import json_schema_type

@@ -14,7 +14,7 @@ from llama_stack_api import json_schema_type
 
 @json_schema_type
 class VLLMInferenceAdapterConfig(RemoteInferenceProviderConfig):
-    url: str | None = Field(
+    base_url: HttpUrl | None = Field(
         default=None,
         description="The URL for the vLLM model serving endpoint",
     )

@@ -48,11 +48,11 @@ class VLLMInferenceAdapterConfig(RemoteInferenceProviderConfig):
     @classmethod
     def sample_run_config(
         cls,
-        url: str = "${env.VLLM_URL:=}",
+        base_url: str = "${env.VLLM_URL:=}",
         **kwargs,
     ):
         return {
-            "url": url,
+            "base_url": base_url,
             "max_tokens": "${env.VLLM_MAX_TOKENS:=4096}",
             "api_token": "${env.VLLM_API_TOKEN:=fake}",
             "tls_verify": "${env.VLLM_TLS_VERIFY:=true}",
@@ -39,12 +39,12 @@ class VLLMInferenceAdapter(OpenAIMixin):
 
     def get_base_url(self) -> str:
         """Get the base URL from config."""
-        if not self.config.url:
+        if not self.config.base_url:
            raise ValueError("No base URL configured")
-        return self.config.url
+        return str(self.config.base_url)
 
     async def initialize(self) -> None:
-        if not self.config.url:
+        if not self.config.base_url:
             raise ValueError(
                 "You must provide a URL in run.yaml (or via the VLLM_URL environment variable) to use vLLM."
             )
@@ -7,7 +7,7 @@
 import os
 from typing import Any
 
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, HttpUrl
 
 from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 from llama_stack_api import json_schema_type

@@ -23,7 +23,7 @@ class WatsonXProviderDataValidator(BaseModel):
 
 @json_schema_type
 class WatsonXConfig(RemoteInferenceProviderConfig):
-    url: str = Field(
+    base_url: HttpUrl | None = Field(
         default_factory=lambda: os.getenv("WATSONX_BASE_URL", "https://us-south.ml.cloud.ibm.com"),
         description="A base url for accessing the watsonx.ai",
     )

@@ -39,7 +39,7 @@ class WatsonXConfig(RemoteInferenceProviderConfig):
     @classmethod
     def sample_run_config(cls, **kwargs) -> dict[str, Any]:
         return {
-            "url": "${env.WATSONX_BASE_URL:=https://us-south.ml.cloud.ibm.com}",
+            "base_url": "${env.WATSONX_BASE_URL:=https://us-south.ml.cloud.ibm.com}",
             "api_key": "${env.WATSONX_API_KEY:=}",
             "project_id": "${env.WATSONX_PROJECT_ID:=}",
         }
@@ -255,7 +255,7 @@ class WatsonXInferenceAdapter(LiteLLMOpenAIMixin):
         )
 
     def get_base_url(self) -> str:
-        return self.config.url
+        return str(self.config.base_url)
 
     # Copied from OpenAIMixin
     async def check_model_availability(self, model: str) -> bool:

@@ -316,7 +316,7 @@ class WatsonXInferenceAdapter(LiteLLMOpenAIMixin):
         """
         Retrieves foundation model specifications from the watsonx.ai API.
         """
-        url = f"{self.config.url}/ml/v1/foundation_model_specs?version=2023-10-25"
+        url = f"{str(self.config.base_url)}/ml/v1/foundation_model_specs?version=2023-10-25"
         headers = {
             # Note that there is no authorization header. Listing models does not require authentication.
             "Content-Type": "application/json",
@@ -50,7 +50,7 @@ SETUP_DEFINITIONS: dict[str, Setup] = {
         name="ollama",
         description="Local Ollama provider with text + safety models",
         env={
-            "OLLAMA_URL": "http://0.0.0.0:11434",
+            "OLLAMA_URL": "http://0.0.0.0:11434/v1",
             "SAFETY_MODEL": "ollama/llama-guard3:1b",
         },
         defaults={

@@ -64,7 +64,7 @@ SETUP_DEFINITIONS: dict[str, Setup] = {
         name="ollama",
         description="Local Ollama provider with a vision model",
         env={
-            "OLLAMA_URL": "http://0.0.0.0:11434",
+            "OLLAMA_URL": "http://0.0.0.0:11434/v1",
         },
         defaults={
             "vision_model": "ollama/llama3.2-vision:11b",

@@ -75,7 +75,7 @@ SETUP_DEFINITIONS: dict[str, Setup] = {
         name="ollama-postgres",
         description="Server-mode tests with Postgres-backed persistence",
         env={
-            "OLLAMA_URL": "http://0.0.0.0:11434",
+            "OLLAMA_URL": "http://0.0.0.0:11434/v1",
             "SAFETY_MODEL": "ollama/llama-guard3:1b",
             "POSTGRES_HOST": "127.0.0.1",
             "POSTGRES_PORT": "5432",
@@ -120,7 +120,7 @@ from llama_stack.providers.remote.inference.watsonx.watsonx import WatsonXInfere
         VLLMInferenceAdapter,
         "llama_stack.providers.remote.inference.vllm.VLLMProviderDataValidator",
         {
-            "url": "http://fake",
+            "base_url": "http://fake",
         },
     ),
 ],

@@ -153,7 +153,7 @@ def test_litellm_provider_data_used(config_cls, adapter_cls, provider_data_valid
     """Validate data for LiteLLM-based providers. Similar to test_openai_provider_data_used, but without the
     assumption that there is an OpenAI-compatible client object."""
 
-    inference_adapter = adapter_cls(config=config_cls())
+    inference_adapter = adapter_cls(config=config_cls(base_url="http://fake"))
 
     inference_adapter.__provider_spec__ = MagicMock()
     inference_adapter.__provider_spec__.provider_data_validator = provider_data_validator
@@ -40,7 +40,7 @@ from llama_stack_api import (
 
 @pytest.fixture(scope="function")
 async def vllm_inference_adapter():
-    config = VLLMInferenceAdapterConfig(url="http://mocked.localhost:12345")
+    config = VLLMInferenceAdapterConfig(base_url="http://mocked.localhost:12345")
     inference_adapter = VLLMInferenceAdapter(config=config)
     inference_adapter.model_store = AsyncMock()
     await inference_adapter.initialize()

@@ -204,7 +204,7 @@ async def test_vllm_completion_extra_body():
     via extra_body to the underlying OpenAI client through the InferenceRouter.
     """
     # Set up the vLLM adapter
-    config = VLLMInferenceAdapterConfig(url="http://mocked.localhost:12345")
+    config = VLLMInferenceAdapterConfig(base_url="http://mocked.localhost:12345")
     vllm_adapter = VLLMInferenceAdapter(config=config)
     vllm_adapter.__provider_id__ = "vllm"
     await vllm_adapter.initialize()

@@ -277,7 +277,7 @@ async def test_vllm_chat_completion_extra_body():
     via extra_body to the underlying OpenAI client through the InferenceRouter for chat completion.
     """
     # Set up the vLLM adapter
-    config = VLLMInferenceAdapterConfig(url="http://mocked.localhost:12345")
+    config = VLLMInferenceAdapterConfig(base_url="http://mocked.localhost:12345")
     vllm_adapter = VLLMInferenceAdapter(config=config)
     vllm_adapter.__provider_id__ = "vllm"
     await vllm_adapter.initialize()
@@ -146,7 +146,7 @@ async def test_hosted_model_not_in_endpoint_mapping():
 
 async def test_self_hosted_ignores_endpoint():
     adapter = create_adapter(
-        config=NVIDIAConfig(url="http://localhost:8000", api_key=None),
+        config=NVIDIAConfig(base_url="http://localhost:8000", api_key=None),
         rerank_endpoints={"test-model": "https://model.endpoint/rerank"},  # This should be ignored for self-hosted.
     )
     mock_session = MockSession(MockResponse())
@@ -4,8 +4,10 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+from typing import get_args, get_origin
+
 import pytest
-from pydantic import BaseModel
+from pydantic import BaseModel, HttpUrl
 
 from llama_stack.core.distribution import get_provider_registry, providable_apis
 from llama_stack.core.utils.dynamic import instantiate_class_type

@@ -41,3 +43,55 @@ class TestProviderConfigurations:
 
         sample_config = config_type.sample_run_config(__distro_dir__="foobarbaz")
         assert isinstance(sample_config, dict), f"{config_class_name}.sample_run_config() did not return a dict"
+
+    def test_remote_inference_url_standardization(self):
+        """Verify all remote inference providers use standardized base_url configuration."""
+        provider_registry = get_provider_registry()
+        inference_providers = provider_registry.get("inference", {})
+
+        # Filter for remote providers only
+        remote_providers = {k: v for k, v in inference_providers.items() if k.startswith("remote::")}
+
+        failures = []
+        for provider_type, provider_spec in remote_providers.items():
+            try:
+                config_class_name = provider_spec.config_class
+                config_type = instantiate_class_type(config_class_name)
+
+                # Check that config has base_url field (not url)
+                if hasattr(config_type, "model_fields"):
+                    fields = config_type.model_fields
+
+                    # Should NOT have 'url' field (old pattern)
+                    if "url" in fields:
+                        failures.append(
+                            f"{provider_type}: Uses deprecated 'url' field instead of 'base_url'. "
+                            f"Please rename to 'base_url' for consistency."
+                        )
+
+                    # Should have 'base_url' field with HttpUrl | None type
+                    if "base_url" in fields:
+                        field_info = fields["base_url"]
+                        annotation = field_info.annotation
+
+                        # Check if it's HttpUrl or HttpUrl | None
+                        # get_origin() returns Union for (X | Y), None for plain types
+                        # get_args() returns the types inside Union, e.g. (HttpUrl, NoneType)
+                        is_valid = False
+                        if get_origin(annotation) is not None:  # It's a Union/Optional
+                            if HttpUrl in get_args(annotation):
+                                is_valid = True
+                        elif annotation == HttpUrl:  # Plain HttpUrl without | None
+                            is_valid = True
+
+                        if not is_valid:
+                            failures.append(
+                                f"{provider_type}: base_url field has incorrect type annotation. "
+                                f"Expected 'HttpUrl | None', got '{annotation}'"
+                            )
+            except Exception as e:
+                failures.append(f"{provider_type}: Error checking URL standardization: {str(e)}")
+
+        if failures:
+            pytest.fail("URL standardization violations found:\n" + "\n".join(f" - {f}" for f in failures))
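For readers unfamiliar with the typing introspection this test leans on, here is what `get_origin`/`get_args` return for the annotations involved (runnable on Python 3.10+; the printed origin may be `types.UnionType` or `typing.Union` depending on how the annotation was built, which is why the test only checks that it is not `None`):

```python
from typing import get_args, get_origin

from pydantic import HttpUrl

ann = HttpUrl | None
print(get_origin(ann))           # a Union origin, so "is not None" in the test
print(get_args(ann))             # (HttpUrl, NoneType): HttpUrl is among the args
print(HttpUrl in get_args(ann))  # True: the membership check the test performs
print(get_origin(str))           # None: plain types have no origin, hence the elif
```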