From e892a3f7f4cafdc1fd0ae1b94e4f8edd11bd0119 Mon Sep 17 00:00:00 2001
From: Matthew Farrellee
Date: Tue, 7 Oct 2025 09:19:56 -0400
Subject: [PATCH] feat: add refresh_models support to inference adapters
 (default: false) (#3719)

# What does this PR do?

Inference adapters can now set `refresh_models: bool` in their config to control periodic model listing from their providers.

BREAKING CHANGE: the Together inference adapter's default changed. It previously always refreshed its model list; it now follows the config, which defaults to `false`.

Addresses the "models: refresh" item on #3517.
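To make the new knob concrete, here is a minimal sketch of how an adapter picks it up (adapted from the unit tests in this patch; the vLLM import paths and constructor arguments mirror the test diff below, so treat them as illustrative rather than a supported public API):

```python
# Minimal sketch: `refresh_models` now lives on the shared
# RemoteInferenceProviderConfig, so any adapter built on OpenAIMixin
# reports it through should_refresh_models().
import asyncio

from llama_stack.providers.remote.inference.vllm.config import VLLMInferenceAdapterConfig
from llama_stack.providers.remote.inference.vllm.vllm import VLLMInferenceAdapter


async def main() -> None:
    # Opt in to periodic model listing; the default is False.
    config = VLLMInferenceAdapterConfig(url="http://test.localhost", refresh_models=True)
    adapter = VLLMInferenceAdapter(config=config)
    assert await adapter.should_refresh_models() is True


asyncio.run(main())
```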
## Test Plan

CI with new tests.
---
 .../providers/inference/remote_anthropic.mdx  |  1 +
 .../docs/providers/inference/remote_azure.mdx |  1 +
 .../providers/inference/remote_bedrock.mdx    |  1 +
 .../providers/inference/remote_cerebras.mdx   |  1 +
 .../providers/inference/remote_databricks.mdx |  1 +
 .../providers/inference/remote_fireworks.mdx  |  1 +
 .../providers/inference/remote_gemini.mdx     |  1 +
 docs/docs/providers/inference/remote_groq.mdx |  1 +
 .../inference/remote_llama-openai-compat.mdx  |  1 +
 .../providers/inference/remote_nvidia.mdx     |  1 +
 .../providers/inference/remote_ollama.mdx     |  2 +-
 .../providers/inference/remote_openai.mdx     |  1 +
 .../inference/remote_passthrough.mdx          |  1 +
 .../providers/inference/remote_runpod.mdx     |  1 +
 .../providers/inference/remote_sambanova.mdx  |  1 +
 docs/docs/providers/inference/remote_tgi.mdx  |  1 +
 .../providers/inference/remote_together.mdx   |  1 +
 .../providers/inference/remote_vertexai.mdx   |  1 +
 docs/docs/providers/inference/remote_vllm.mdx |  2 +-
 .../providers/inference/remote_watsonx.mdx    |  1 +
 docs/docs/providers/safety/remote_bedrock.mdx |  1 +
 .../remote/inference/databricks/databricks.py |  3 --
 .../remote/inference/ollama/config.py         |  6 ---
 .../remote/inference/ollama/ollama.py         |  3 --
 .../remote/inference/together/together.py     |  3 --
 .../providers/remote/inference/vllm/config.py |  4 --
 .../providers/remote/inference/vllm/vllm.py   |  4 --
 .../utils/inference/model_registry.py         |  4 ++
 .../providers/utils/inference/openai_mixin.py |  2 +-
 .../providers/inference/test_remote_vllm.py   | 40 -------------------
 .../utils/inference/test_openai_mixin.py      |  8 +++-
 31 files changed, 33 insertions(+), 67 deletions(-)

diff --git a/docs/docs/providers/inference/remote_anthropic.mdx b/docs/docs/providers/inference/remote_anthropic.mdx
index 96162d25c..44c1fcbb1 100644
--- a/docs/docs/providers/inference/remote_anthropic.mdx
+++ b/docs/docs/providers/inference/remote_anthropic.mdx
@@ -15,6 +15,7 @@ Anthropic inference provider for accessing Claude models and Anthropic's AI services.
 | Field | Type | Required | Default | Description |
 |-------|------|----------|---------|-------------|
 | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
+| `refresh_models` | `` | No | False | Whether to refresh models periodically from the provider |
 | `api_key` | `str \| None` | No | | API key for Anthropic models |
 
 ## Sample Configuration
diff --git a/docs/docs/providers/inference/remote_azure.mdx b/docs/docs/providers/inference/remote_azure.mdx
index 721fe429c..56a14c100 100644
--- a/docs/docs/providers/inference/remote_azure.mdx
+++ b/docs/docs/providers/inference/remote_azure.mdx
@@ -22,6 +22,7 @@ https://learn.microsoft.com/en-us/azure/ai-foundry/openai/overview
 | Field | Type | Required | Default | Description |
 |-------|------|----------|---------|-------------|
 | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
+| `refresh_models` | `` | No | False | Whether to refresh models periodically from the provider |
 | `api_key` | `` | No | | Azure API key for Azure |
 | `api_base` | `` | No | | Azure API base for Azure (e.g., https://your-resource-name.openai.azure.com) |
 | `api_version` | `str \| None` | No | | Azure API version for Azure (e.g., 2024-12-01-preview) |
diff --git a/docs/docs/providers/inference/remote_bedrock.mdx b/docs/docs/providers/inference/remote_bedrock.mdx
index 2a5d1b74d..683ec12f8 100644
--- a/docs/docs/providers/inference/remote_bedrock.mdx
+++ b/docs/docs/providers/inference/remote_bedrock.mdx
@@ -15,6 +15,7 @@ AWS Bedrock inference provider for accessing various AI models through AWS's managed infrastructure.
 | Field | Type | Required | Default | Description |
 |-------|------|----------|---------|-------------|
 | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
+| `refresh_models` | `` | No | False | Whether to refresh models periodically from the provider |
 | `aws_access_key_id` | `str \| None` | No | | The AWS access key to use. Default use environment variable: AWS_ACCESS_KEY_ID |
 | `aws_secret_access_key` | `str \| None` | No | | The AWS secret access key to use. Default use environment variable: AWS_SECRET_ACCESS_KEY |
 | `aws_session_token` | `str \| None` | No | | The AWS session token to use. Default use environment variable: AWS_SESSION_TOKEN |
diff --git a/docs/docs/providers/inference/remote_cerebras.mdx b/docs/docs/providers/inference/remote_cerebras.mdx
index 1a543389d..d364b9884 100644
--- a/docs/docs/providers/inference/remote_cerebras.mdx
+++ b/docs/docs/providers/inference/remote_cerebras.mdx
@@ -15,6 +15,7 @@ Cerebras inference provider for running models on Cerebras Cloud platform.
 | Field | Type | Required | Default | Description |
 |-------|------|----------|---------|-------------|
 | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
+| `refresh_models` | `` | No | False | Whether to refresh models periodically from the provider |
 | `base_url` | `` | No | https://api.cerebras.ai | Base URL for the Cerebras API |
 | `api_key` | `` | No | | Cerebras API Key |
diff --git a/docs/docs/providers/inference/remote_databricks.mdx b/docs/docs/providers/inference/remote_databricks.mdx
index 670f8a7f9..d7b0bd38d 100644
--- a/docs/docs/providers/inference/remote_databricks.mdx
+++ b/docs/docs/providers/inference/remote_databricks.mdx
@@ -15,6 +15,7 @@ Databricks inference provider for running models on Databricks' unified analytics platform.
 | Field | Type | Required | Default | Description |
 |-------|------|----------|---------|-------------|
 | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
+| `refresh_models` | `` | No | False | Whether to refresh models periodically from the provider |
 | `url` | `str \| None` | No | | The URL for the Databricks model serving endpoint |
 | `api_token` | `` | No | | The Databricks API token |
diff --git a/docs/docs/providers/inference/remote_fireworks.mdx b/docs/docs/providers/inference/remote_fireworks.mdx
index d2c3a664e..cfdfb993c 100644
--- a/docs/docs/providers/inference/remote_fireworks.mdx
+++ b/docs/docs/providers/inference/remote_fireworks.mdx
@@ -15,6 +15,7 @@ Fireworks AI inference provider for Llama models and other AI models on the Fireworks platform.
 | Field | Type | Required | Default | Description |
 |-------|------|----------|---------|-------------|
 | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
+| `refresh_models` | `` | No | False | Whether to refresh models periodically from the provider |
 | `url` | `` | No | https://api.fireworks.ai/inference/v1 | The URL for the Fireworks server |
 | `api_key` | `pydantic.types.SecretStr \| None` | No | | The Fireworks.ai API Key |
diff --git a/docs/docs/providers/inference/remote_gemini.mdx b/docs/docs/providers/inference/remote_gemini.mdx
index 5222eaa89..a13d1c82d 100644
--- a/docs/docs/providers/inference/remote_gemini.mdx
+++ b/docs/docs/providers/inference/remote_gemini.mdx
@@ -15,6 +15,7 @@ Google Gemini inference provider for accessing Gemini models and Google's AI services.
 | Field | Type | Required | Default | Description |
 |-------|------|----------|---------|-------------|
 | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
+| `refresh_models` | `` | No | False | Whether to refresh models periodically from the provider |
 | `api_key` | `str \| None` | No | | API key for Gemini models |
 
 ## Sample Configuration
diff --git a/docs/docs/providers/inference/remote_groq.mdx b/docs/docs/providers/inference/remote_groq.mdx
index 77516ed1f..1edb4f9ea 100644
--- a/docs/docs/providers/inference/remote_groq.mdx
+++ b/docs/docs/providers/inference/remote_groq.mdx
@@ -15,6 +15,7 @@ Groq inference provider for ultra-fast inference using Groq's LPU technology.
 | Field | Type | Required | Default | Description |
 |-------|------|----------|---------|-------------|
 | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
+| `refresh_models` | `` | No | False | Whether to refresh models periodically from the provider |
 | `api_key` | `str \| None` | No | | The Groq API key |
 | `url` | `` | No | https://api.groq.com | The URL for the Groq AI server |
diff --git a/docs/docs/providers/inference/remote_llama-openai-compat.mdx b/docs/docs/providers/inference/remote_llama-openai-compat.mdx
index bcd50f772..ca5830b09 100644
--- a/docs/docs/providers/inference/remote_llama-openai-compat.mdx
+++ b/docs/docs/providers/inference/remote_llama-openai-compat.mdx
@@ -15,6 +15,7 @@ Llama OpenAI-compatible provider for using Llama models with OpenAI API format.
 | Field | Type | Required | Default | Description |
 |-------|------|----------|---------|-------------|
 | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
+| `refresh_models` | `` | No | False | Whether to refresh models periodically from the provider |
 | `api_key` | `str \| None` | No | | The Llama API key |
 | `openai_compat_api_base` | `` | No | https://api.llama.com/compat/v1/ | The URL for the Llama API server |
diff --git a/docs/docs/providers/inference/remote_nvidia.mdx b/docs/docs/providers/inference/remote_nvidia.mdx
index 348a42e59..6b5e36180 100644
--- a/docs/docs/providers/inference/remote_nvidia.mdx
+++ b/docs/docs/providers/inference/remote_nvidia.mdx
@@ -15,6 +15,7 @@ NVIDIA inference provider for accessing NVIDIA NIM models and AI services.
 | Field | Type | Required | Default | Description |
 |-------|------|----------|---------|-------------|
 | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
+| `refresh_models` | `` | No | False | Whether to refresh models periodically from the provider |
 | `url` | `` | No | https://integrate.api.nvidia.com | A base url for accessing the NVIDIA NIM |
 | `api_key` | `pydantic.types.SecretStr \| None` | No | | The NVIDIA API key, only needed of using the hosted service |
 | `timeout` | `` | No | 60 | Timeout for the HTTP requests |
diff --git a/docs/docs/providers/inference/remote_ollama.mdx b/docs/docs/providers/inference/remote_ollama.mdx
index f075607d8..e00e34e4a 100644
--- a/docs/docs/providers/inference/remote_ollama.mdx
+++ b/docs/docs/providers/inference/remote_ollama.mdx
@@ -15,8 +15,8 @@ Ollama inference provider for running local models through the Ollama runtime.
 | Field | Type | Required | Default | Description |
 |-------|------|----------|---------|-------------|
 | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
+| `refresh_models` | `` | No | False | Whether to refresh models periodically from the provider |
 | `url` | `` | No | http://localhost:11434 | |
-| `refresh_models` | `` | No | False | Whether to refresh models periodically |
 
 ## Sample Configuration
diff --git a/docs/docs/providers/inference/remote_openai.mdx b/docs/docs/providers/inference/remote_openai.mdx
index b795d02b1..e0910c809 100644
--- a/docs/docs/providers/inference/remote_openai.mdx
+++ b/docs/docs/providers/inference/remote_openai.mdx
@@ -15,6 +15,7 @@ OpenAI inference provider for accessing GPT models and other OpenAI services.
 | Field | Type | Required | Default | Description |
 |-------|------|----------|---------|-------------|
 | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
+| `refresh_models` | `` | No | False | Whether to refresh models periodically from the provider |
 | `api_key` | `str \| None` | No | | API key for OpenAI models |
 | `base_url` | `` | No | https://api.openai.com/v1 | Base URL for OpenAI API |
diff --git a/docs/docs/providers/inference/remote_passthrough.mdx b/docs/docs/providers/inference/remote_passthrough.mdx
index 58d5619b8..e356384ad 100644
--- a/docs/docs/providers/inference/remote_passthrough.mdx
+++ b/docs/docs/providers/inference/remote_passthrough.mdx
@@ -15,6 +15,7 @@ Passthrough inference provider for connecting to any external inference service.
 | Field | Type | Required | Default | Description |
 |-------|------|----------|---------|-------------|
 | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
+| `refresh_models` | `` | No | False | Whether to refresh models periodically from the provider |
 | `url` | `` | No | | The URL for the passthrough endpoint |
 | `api_key` | `pydantic.types.SecretStr \| None` | No | | API Key for the passthrouth endpoint |
diff --git a/docs/docs/providers/inference/remote_runpod.mdx b/docs/docs/providers/inference/remote_runpod.mdx
index 92cc66eb1..876532029 100644
--- a/docs/docs/providers/inference/remote_runpod.mdx
+++ b/docs/docs/providers/inference/remote_runpod.mdx
@@ -15,6 +15,7 @@ RunPod inference provider for running models on RunPod's cloud GPU platform.
 | Field | Type | Required | Default | Description |
 |-------|------|----------|---------|-------------|
 | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
+| `refresh_models` | `` | No | False | Whether to refresh models periodically from the provider |
 | `url` | `str \| None` | No | | The URL for the Runpod model serving endpoint |
 | `api_token` | `str \| None` | No | | The API token |
diff --git a/docs/docs/providers/inference/remote_sambanova.mdx b/docs/docs/providers/inference/remote_sambanova.mdx
index b28471890..9bd7b7613 100644
--- a/docs/docs/providers/inference/remote_sambanova.mdx
+++ b/docs/docs/providers/inference/remote_sambanova.mdx
@@ -15,6 +15,7 @@ SambaNova inference provider for running models on SambaNova's dataflow architecture.
 | Field | Type | Required | Default | Description |
 |-------|------|----------|---------|-------------|
 | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
+| `refresh_models` | `` | No | False | Whether to refresh models periodically from the provider |
 | `url` | `` | No | https://api.sambanova.ai/v1 | The URL for the SambaNova AI server |
 | `api_key` | `pydantic.types.SecretStr \| None` | No | | The SambaNova cloud API Key |
diff --git a/docs/docs/providers/inference/remote_tgi.mdx b/docs/docs/providers/inference/remote_tgi.mdx
index 6ff82cc2b..67fe6d237 100644
--- a/docs/docs/providers/inference/remote_tgi.mdx
+++ b/docs/docs/providers/inference/remote_tgi.mdx
@@ -15,6 +15,7 @@ Text Generation Inference (TGI) provider for HuggingFace model serving.
 | Field | Type | Required | Default | Description |
 |-------|------|----------|---------|-------------|
 | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
+| `refresh_models` | `` | No | False | Whether to refresh models periodically from the provider |
 | `url` | `` | No | | The URL for the TGI serving endpoint |
 
 ## Sample Configuration
diff --git a/docs/docs/providers/inference/remote_together.mdx b/docs/docs/providers/inference/remote_together.mdx
index da232a45b..6df2ca866 100644
--- a/docs/docs/providers/inference/remote_together.mdx
+++ b/docs/docs/providers/inference/remote_together.mdx
@@ -15,6 +15,7 @@ Together AI inference provider for open-source models and collaborative AI development.
 | Field | Type | Required | Default | Description |
 |-------|------|----------|---------|-------------|
 | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
+| `refresh_models` | `` | No | False | Whether to refresh models periodically from the provider |
 | `url` | `` | No | https://api.together.xyz/v1 | The URL for the Together AI server |
 | `api_key` | `pydantic.types.SecretStr \| None` | No | | The Together AI API Key |
diff --git a/docs/docs/providers/inference/remote_vertexai.mdx b/docs/docs/providers/inference/remote_vertexai.mdx
index 48da6be24..c182ed485 100644
--- a/docs/docs/providers/inference/remote_vertexai.mdx
+++ b/docs/docs/providers/inference/remote_vertexai.mdx
@@ -54,6 +54,7 @@ Available Models:
 | Field | Type | Required | Default | Description |
 |-------|------|----------|---------|-------------|
 | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
+| `refresh_models` | `` | No | False | Whether to refresh models periodically from the provider |
 | `project` | `` | No | | Google Cloud project ID for Vertex AI |
 | `location` | `` | No | us-central1 | Google Cloud location for Vertex AI |
diff --git a/docs/docs/providers/inference/remote_vllm.mdx b/docs/docs/providers/inference/remote_vllm.mdx
index 598f97b19..fbbd424a3 100644
--- a/docs/docs/providers/inference/remote_vllm.mdx
+++ b/docs/docs/providers/inference/remote_vllm.mdx
@@ -15,11 +15,11 @@ Remote vLLM inference provider for connecting to vLLM servers.
 | Field | Type | Required | Default | Description |
 |-------|------|----------|---------|-------------|
 | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
+| `refresh_models` | `` | No | False | Whether to refresh models periodically from the provider |
 | `url` | `str \| None` | No | | The URL for the vLLM model serving endpoint |
 | `max_tokens` | `` | No | 4096 | Maximum number of tokens to generate. |
 | `api_token` | `str \| None` | No | fake | The API token |
 | `tls_verify` | `bool \| str` | No | True | Whether to verify TLS certificates. Can be a boolean or a path to a CA certificate file. |
-| `refresh_models` | `` | No | False | Whether to refresh models periodically |
 
 ## Sample Configuration
diff --git a/docs/docs/providers/inference/remote_watsonx.mdx b/docs/docs/providers/inference/remote_watsonx.mdx
index 8cd3b2869..33bc5bbc3 100644
--- a/docs/docs/providers/inference/remote_watsonx.mdx
+++ b/docs/docs/providers/inference/remote_watsonx.mdx
@@ -15,6 +15,7 @@ IBM WatsonX inference provider for accessing AI models on IBM's WatsonX platform.
 | Field | Type | Required | Default | Description |
 |-------|------|----------|---------|-------------|
 | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
+| `refresh_models` | `` | No | False | Whether to refresh models periodically from the provider |
 | `url` | `` | No | https://us-south.ml.cloud.ibm.com | A base url for accessing the watsonx.ai |
 | `api_key` | `pydantic.types.SecretStr \| None` | No | | The watsonx API key |
 | `project_id` | `str \| None` | No | | The Project ID key |
diff --git a/docs/docs/providers/safety/remote_bedrock.mdx b/docs/docs/providers/safety/remote_bedrock.mdx
index 530a208b5..663a761f0 100644
--- a/docs/docs/providers/safety/remote_bedrock.mdx
+++ b/docs/docs/providers/safety/remote_bedrock.mdx
@@ -15,6 +15,7 @@ AWS Bedrock safety provider for content moderation using AWS's safety services.
 | Field | Type | Required | Default | Description |
 |-------|------|----------|---------|-------------|
 | `allowed_models` | `list[str \| None` | No | | List of models that should be registered with the model registry. If None, all models are allowed. |
+| `refresh_models` | `` | No | False | Whether to refresh models periodically from the provider |
 | `aws_access_key_id` | `str \| None` | No | | The AWS access key to use. Default use environment variable: AWS_ACCESS_KEY_ID |
 | `aws_secret_access_key` | `str \| None` | No | | The AWS secret access key to use. Default use environment variable: AWS_SECRET_ACCESS_KEY |
 | `aws_session_token` | `str \| None` | No | | The AWS session token to use. Default use environment variable: AWS_SESSION_TOKEN |
diff --git a/llama_stack/providers/remote/inference/databricks/databricks.py b/llama_stack/providers/remote/inference/databricks/databricks.py
index f4ad1be94..200b36171 100644
--- a/llama_stack/providers/remote/inference/databricks/databricks.py
+++ b/llama_stack/providers/remote/inference/databricks/databricks.py
@@ -41,9 +41,6 @@ class DatabricksInferenceAdapter(OpenAIMixin):
             ).serving_endpoints.list()  # TODO: this is not async
         ]
 
-    async def should_refresh_models(self) -> bool:
-        return False
-
     async def openai_completion(
         self,
         model: str,
diff --git a/llama_stack/providers/remote/inference/ollama/config.py b/llama_stack/providers/remote/inference/ollama/config.py
index d2f104e1e..1e4ce9113 100644
--- a/llama_stack/providers/remote/inference/ollama/config.py
+++ b/llama_stack/providers/remote/inference/ollama/config.py
@@ -6,8 +6,6 @@
 
 from typing import Any
 
-from pydantic import Field
-
 from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
 
 DEFAULT_OLLAMA_URL = "http://localhost:11434"
@@ -15,10 +13,6 @@ DEFAULT_OLLAMA_URL = "http://localhost:11434"
 
 class OllamaImplConfig(RemoteInferenceProviderConfig):
     url: str = DEFAULT_OLLAMA_URL
-    refresh_models: bool = Field(
-        default=False,
-        description="Whether to refresh models periodically",
-    )
 
     @classmethod
     def sample_run_config(cls, url: str = "${env.OLLAMA_URL:=http://localhost:11434}", **kwargs) -> dict[str, Any]:
diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py
index e5b08997c..67d0caa54 100644
--- a/llama_stack/providers/remote/inference/ollama/ollama.py
+++ b/llama_stack/providers/remote/inference/ollama/ollama.py
@@ -72,9 +72,6 @@ class OllamaInferenceAdapter(OpenAIMixin):
                 f"Ollama Server is not running (message: {r['message']}). Make sure to start it using `ollama serve` in a separate terminal"
             )
 
-    async def should_refresh_models(self) -> bool:
-        return self.config.refresh_models
-
     async def health(self) -> HealthResponse:
         """
         Performs a health check by verifying connectivity to the Ollama server.
diff --git a/llama_stack/providers/remote/inference/together/together.py b/llama_stack/providers/remote/inference/together/together.py
index fbefe630f..224de6721 100644
--- a/llama_stack/providers/remote/inference/together/together.py
+++ b/llama_stack/providers/remote/inference/together/together.py
@@ -63,9 +63,6 @@ class TogetherInferenceAdapter(OpenAIMixin, NeedsRequestProviderData):
         # Together's /v1/models is not compatible with OpenAI's /v1/models. Together support ticket #13355 -> will not fix, use Together's own client
         return [m.id for m in await self._get_client().models.list()]
 
-    async def should_refresh_models(self) -> bool:
-        return True
-
     async def openai_embeddings(
         self,
         model: str,
diff --git a/llama_stack/providers/remote/inference/vllm/config.py b/llama_stack/providers/remote/inference/vllm/config.py
index 86ef3fe26..87c5408d3 100644
--- a/llama_stack/providers/remote/inference/vllm/config.py
+++ b/llama_stack/providers/remote/inference/vllm/config.py
@@ -30,10 +30,6 @@ class VLLMInferenceAdapterConfig(RemoteInferenceProviderConfig):
         default=True,
         description="Whether to verify TLS certificates. Can be a boolean or a path to a CA certificate file.",
     )
-    refresh_models: bool = Field(
-        default=False,
-        description="Whether to refresh models periodically",
-    )
 
     @field_validator("tls_verify")
     @classmethod
diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py
index 4e7884cd2..310eaf7b6 100644
--- a/llama_stack/providers/remote/inference/vllm/vllm.py
+++ b/llama_stack/providers/remote/inference/vllm/vllm.py
@@ -53,10 +53,6 @@ class VLLMInferenceAdapter(OpenAIMixin):
                 "You must provide a URL in run.yaml (or via the VLLM_URL environment variable) to use vLLM."
             )
 
-    async def should_refresh_models(self) -> bool:
-        # Strictly respecting the refresh_models directive
-        return self.config.refresh_models
-
     async def health(self) -> HealthResponse:
         """
         Performs a health check by verifying connectivity to the remote vLLM server.
diff --git a/llama_stack/providers/utils/inference/model_registry.py b/llama_stack/providers/utils/inference/model_registry.py
index 4913c2e1f..9d42d68c6 100644
--- a/llama_stack/providers/utils/inference/model_registry.py
+++ b/llama_stack/providers/utils/inference/model_registry.py
@@ -24,6 +24,10 @@ class RemoteInferenceProviderConfig(BaseModel):
         default=None,
         description="List of models that should be registered with the model registry. If None, all models are allowed.",
     )
+    refresh_models: bool = Field(
+        default=False,
+        description="Whether to refresh models periodically from the provider",
+    )
 
 
 # TODO: this class is more confusing than useful right now. We need to make it
diff --git a/llama_stack/providers/utils/inference/openai_mixin.py b/llama_stack/providers/utils/inference/openai_mixin.py
index 9137013ee..3c5c5b4de 100644
--- a/llama_stack/providers/utils/inference/openai_mixin.py
+++ b/llama_stack/providers/utils/inference/openai_mixin.py
@@ -484,7 +484,7 @@ class OpenAIMixin(NeedsRequestProviderData, ABC, BaseModel):
         return model in self._model_cache
 
     async def should_refresh_models(self) -> bool:
-        return False
+        return self.config.refresh_models
 
     #
     # The model_dump implementations are to avoid serializing the extra fields,
diff --git a/tests/unit/providers/inference/test_remote_vllm.py b/tests/unit/providers/inference/test_remote_vllm.py
index 2806f618c..6d6bb20d5 100644
--- a/tests/unit/providers/inference/test_remote_vllm.py
+++ b/tests/unit/providers/inference/test_remote_vllm.py
@@ -186,43 +186,3 @@ async def test_openai_chat_completion_is_async(vllm_inference_adapter):
 
     assert mock_create_client.call_count == 4  # no cheating
     assert total_time < (sleep_time * 2), f"Total time taken: {total_time}s exceeded expected max"
-
-
-async def test_should_refresh_models():
-    """
-    Test the should_refresh_models method with different refresh_models configurations.
-
-    This test verifies that:
-    1. When refresh_models is True, should_refresh_models returns True regardless of api_token
-    2. When refresh_models is False, should_refresh_models returns False regardless of api_token
-    """
-
-    # Test case 1: refresh_models is True, api_token is None
-    config1 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token=None, refresh_models=True)
-    adapter1 = VLLMInferenceAdapter(config=config1)
-    result1 = await adapter1.should_refresh_models()
-    assert result1 is True, "should_refresh_models should return True when refresh_models is True"
-
-    # Test case 2: refresh_models is True, api_token is empty string
-    config2 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token="", refresh_models=True)
-    adapter2 = VLLMInferenceAdapter(config=config2)
-    result2 = await adapter2.should_refresh_models()
-    assert result2 is True, "should_refresh_models should return True when refresh_models is True"
-
-    # Test case 3: refresh_models is True, api_token is "fake" (default)
-    config3 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token="fake", refresh_models=True)
-    adapter3 = VLLMInferenceAdapter(config=config3)
-    result3 = await adapter3.should_refresh_models()
-    assert result3 is True, "should_refresh_models should return True when refresh_models is True"
-
-    # Test case 4: refresh_models is True, api_token is real token
-    config4 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token="real-token-123", refresh_models=True)
-    adapter4 = VLLMInferenceAdapter(config=config4)
-    result4 = await adapter4.should_refresh_models()
-    assert result4 is True, "should_refresh_models should return True when refresh_models is True"
-
-    # Test case 5: refresh_models is False, api_token is real token
-    config5 = VLLMInferenceAdapterConfig(url="http://test.localhost", api_token="real-token-456", refresh_models=False)
-    adapter5 = VLLMInferenceAdapter(config=config5)
-    result5 = await adapter5.should_refresh_models()
-    assert result5 is False, "should_refresh_models should return False when refresh_models is False"
diff --git a/tests/unit/providers/utils/inference/test_openai_mixin.py b/tests/unit/providers/utils/inference/test_openai_mixin.py
index ac4c29fea..2e3a62ca6 100644
--- a/tests/unit/providers/utils/inference/test_openai_mixin.py
+++ b/tests/unit/providers/utils/inference/test_openai_mixin.py
@@ -466,10 +466,16 @@ class TestOpenAIMixinModelRegistration:
             assert result is None
 
     async def test_should_refresh_models(self, mixin):
-        """Test should_refresh_models method (should always return False)"""
+        """Test should_refresh_models method returns config value"""
+        # Default config has refresh_models=False
         result = await mixin.should_refresh_models()
         assert result is False
 
+        config_with_refresh = RemoteInferenceProviderConfig(refresh_models=True)
+        mixin_with_refresh = OpenAIMixinImpl(config=config_with_refresh)
+        result_with_refresh = await mixin_with_refresh.should_refresh_models()
+        assert result_with_refresh is True
+
     async def test_register_model_error_propagation(self, mixin, mock_client_with_exception, mock_client_context):
         """Test that errors from provider API are properly propagated during registration"""
         model = Model(