Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-05 02:17:31 +00:00)

Merge branch 'main' into fix/list-command-show-builtin-distros

Commit e5e5e3b5fe: 344 changed files with 50428 additions and 80186 deletions.

@@ -268,3 +268,50 @@ def test_generate_run_config_from_providers():
     # Verify config can be parsed back
     parsed = parse_and_maybe_upgrade_config(config_dict)
     assert parsed.image_name == "providers-run"
+
+
+def test_providers_flag_generates_config_with_api_keys():
+    """Test that --providers flag properly generates provider configs including API keys.
+
+    This tests the fix where sample_run_config() is called to populate
+    API keys and other credentials for remote providers like remote::openai.
+    """
+    import argparse
+    from unittest.mock import patch
+
+    from llama_stack.cli.stack.run import StackRun
+
+    parser = argparse.ArgumentParser()
+    subparsers = parser.add_subparsers()
+    stack_run = StackRun(subparsers)
+
+    # Create args with --providers flag set
+    args = argparse.Namespace(
+        providers="inference=remote::openai",
+        config=None,
+        port=8321,
+        image_type=None,
+        image_name=None,
+        enable_ui=False,
+    )
+
+    # Mock _uvicorn_run to prevent starting a server
+    with patch.object(stack_run, "_uvicorn_run"):
+        stack_run._run_stack_run_cmd(args)
+
+    # Read the generated config file
+    from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR
+
+    config_file = DISTRIBS_BASE_DIR / "providers-run" / "run.yaml"
+    with open(config_file) as f:
+        config_dict = yaml.safe_load(f)
+
+    # Verify the provider has config with API keys
+    inference_providers = config_dict["providers"]["inference"]
+    assert len(inference_providers) == 1
+
+    openai_provider = inference_providers[0]
+    assert openai_provider["provider_type"] == "remote::openai"
+    assert openai_provider["config"], "Provider config should not be empty"
+    assert "api_key" in openai_provider["config"], "API key should be in provider config"
+    assert "base_url" in openai_provider["config"], "Base URL should be in provider config"
@@ -166,6 +166,14 @@ async def test_models_routing_table(cached_disk_dist_registry):
     assert "test_provider/test-model" in openai_model_ids
     assert "test_provider/test-model-2" in openai_model_ids
 
+    # Verify custom_metadata is populated with Llama Stack-specific data
+    for openai_model in openai_models.data:
+        assert openai_model.custom_metadata is not None
+        assert "model_type" in openai_model.custom_metadata
+        assert "provider_id" in openai_model.custom_metadata
+        assert "provider_resource_id" in openai_model.custom_metadata
+        assert openai_model.custom_metadata["provider_id"] == "test_provider"
+
     # Test get_object_by_identifier
     model = await table.get_object_by_identifier("model", "test_provider/test-model")
     assert model is not None
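
For reference, the shape these new assertions pin down is small. A hedged, illustrative example of the custom_metadata carried by an OpenAI-compatible model entry might look like the following (the model_type value is only a plausible placeholder, not asserted verbatim by the test):

expected_custom_metadata = {
    "model_type": "llm",  # plausible value; the test only checks the key exists
    "provider_id": "test_provider",  # asserted to equal the registering provider
    "provider_resource_id": "test-model",  # the provider-side identifier
}
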
@@ -5,7 +5,6 @@
 # the root directory of this source tree.
 
 import os
-import unittest
 from unittest.mock import MagicMock, patch
 
 import pytest

@@ -13,6 +12,8 @@ import pytest
 from llama_stack.apis.benchmarks import Benchmark
 from llama_stack.apis.common.job_types import Job, JobStatus
 from llama_stack.apis.eval.eval import BenchmarkConfig, EvaluateResponse, ModelCandidate, SamplingParams
+from llama_stack.apis.inference.inference import TopPSamplingStrategy
+from llama_stack.apis.resource import ResourceType
 from llama_stack.models.llama.sku_types import CoreModelId
 from llama_stack.providers.remote.eval.nvidia.config import NVIDIAEvalConfig
 from llama_stack.providers.remote.eval.nvidia.eval import NVIDIAEvalImpl

@@ -21,193 +22,200 @@ MOCK_DATASET_ID = "default/test-dataset"
 MOCK_BENCHMARK_ID = "test-benchmark"
 
 
-class TestNVIDIAEvalImpl(unittest.TestCase):
-    def setUp(self):
-        os.environ["NVIDIA_EVALUATOR_URL"] = "http://nemo.test"
-
-        # Create mock APIs
-        self.datasetio_api = MagicMock()
-        self.datasets_api = MagicMock()
-        self.scoring_api = MagicMock()
-        self.inference_api = MagicMock()
-        self.agents_api = MagicMock()
-
-        self.config = NVIDIAEvalConfig(
-            evaluator_url=os.environ["NVIDIA_EVALUATOR_URL"],
-        )
-
-        self.eval_impl = NVIDIAEvalImpl(
-            config=self.config,
-            datasetio_api=self.datasetio_api,
-            datasets_api=self.datasets_api,
-            scoring_api=self.scoring_api,
-            inference_api=self.inference_api,
-            agents_api=self.agents_api,
-        )
-
-        # Mock the HTTP request methods
-        self.evaluator_get_patcher = patch(
-            "llama_stack.providers.remote.eval.nvidia.eval.NVIDIAEvalImpl._evaluator_get"
-        )
-        self.evaluator_post_patcher = patch(
-            "llama_stack.providers.remote.eval.nvidia.eval.NVIDIAEvalImpl._evaluator_post"
-        )
-        self.evaluator_delete_patcher = patch(
-            "llama_stack.providers.remote.eval.nvidia.eval.NVIDIAEvalImpl._evaluator_delete"
-        )
-
-        self.mock_evaluator_get = self.evaluator_get_patcher.start()
-        self.mock_evaluator_post = self.evaluator_post_patcher.start()
-        self.mock_evaluator_delete = self.evaluator_delete_patcher.start()
-
-    def tearDown(self):
-        """Clean up after each test."""
-        self.evaluator_get_patcher.stop()
-        self.evaluator_post_patcher.stop()
-        self.evaluator_delete_patcher.stop()
-
-    def _assert_request_body(self, expected_json):
-        """Helper method to verify request body in Evaluator POST request is correct"""
-        call_args = self.mock_evaluator_post.call_args
-        actual_json = call_args[0][1]
-
-        # Check that all expected keys contain the expected values in the actual JSON
-        for key, value in expected_json.items():
-            assert key in actual_json, f"Key '{key}' missing in actual JSON"
-
-            if isinstance(value, dict):
-                for nested_key, nested_value in value.items():
-                    assert nested_key in actual_json[key], f"Nested key '{nested_key}' missing in actual JSON['{key}']"
-                    assert actual_json[key][nested_key] == nested_value, f"Value mismatch for '{key}.{nested_key}'"
-            else:
-                assert actual_json[key] == value, f"Value mismatch for '{key}'"
-
-    @pytest.fixture(autouse=True)
-    def inject_fixtures(self, run_async):
-        self.run_async = run_async
-
-    def test_register_benchmark(self):
-        eval_config = {
-            "type": "custom",
-            "params": {"parallelism": 8},
-            "tasks": {
-                "qa": {
-                    "type": "completion",
-                    "params": {"template": {"prompt": "{{prompt}}", "max_tokens": 200}},
-                    "dataset": {"files_url": f"hf://datasets/{MOCK_DATASET_ID}/testing/testing.jsonl"},
-                    "metrics": {"bleu": {"type": "bleu", "params": {"references": ["{{ideal_response}}"]}}},
-                }
-            },
-        }
-
-        benchmark = Benchmark(
-            provider_id="nvidia",
-            type="benchmark",
-            identifier=MOCK_BENCHMARK_ID,
-            dataset_id=MOCK_DATASET_ID,
-            scoring_functions=["basic::equality"],
-            metadata=eval_config,
-        )
-
-        # Mock Evaluator API response
-        mock_evaluator_response = {"id": MOCK_BENCHMARK_ID, "status": "created"}
-        self.mock_evaluator_post.return_value = mock_evaluator_response
-
-        # Register the benchmark
-        self.run_async(self.eval_impl.register_benchmark(benchmark))
-
-        # Verify the Evaluator API was called correctly
-        self.mock_evaluator_post.assert_called_once()
-        self._assert_request_body({"namespace": benchmark.provider_id, "name": benchmark.identifier, **eval_config})
-
-    def test_unregister_benchmark(self):
-        # Unregister the benchmark
-        self.run_async(self.eval_impl.unregister_benchmark(benchmark_id=MOCK_BENCHMARK_ID))
-
-        # Verify the Evaluator API was called correctly
-        self.mock_evaluator_delete.assert_called_once_with(f"/v1/evaluation/configs/nvidia/{MOCK_BENCHMARK_ID}")
-
-    def test_run_eval(self):
-        benchmark_config = BenchmarkConfig(
-            eval_candidate=ModelCandidate(
-                type="model",
-                model=CoreModelId.llama3_1_8b_instruct.value,
-                sampling_params=SamplingParams(max_tokens=100, temperature=0.7),
-            )
-        )
-
-        # Mock Evaluator API response
-        mock_evaluator_response = {"id": "job-123", "status": "created"}
-        self.mock_evaluator_post.return_value = mock_evaluator_response
-
-        # Run the Evaluation job
-        result = self.run_async(
-            self.eval_impl.run_eval(benchmark_id=MOCK_BENCHMARK_ID, benchmark_config=benchmark_config)
-        )
-
-        # Verify the Evaluator API was called correctly
-        self.mock_evaluator_post.assert_called_once()
-        self._assert_request_body(
-            {
-                "config": f"nvidia/{MOCK_BENCHMARK_ID}",
-                "target": {"type": "model", "model": "Llama3.1-8B-Instruct"},
-            }
-        )
-
-        # Verify the result
-        assert isinstance(result, Job)
-        assert result.job_id == "job-123"
-        assert result.status == JobStatus.in_progress
-
-    def test_job_status(self):
-        # Mock Evaluator API response
-        mock_evaluator_response = {"id": "job-123", "status": "completed"}
-        self.mock_evaluator_get.return_value = mock_evaluator_response
-
-        # Get the Evaluation job
-        result = self.run_async(self.eval_impl.job_status(benchmark_id=MOCK_BENCHMARK_ID, job_id="job-123"))
-
-        # Verify the result
-        assert isinstance(result, Job)
-        assert result.job_id == "job-123"
-        assert result.status == JobStatus.completed
-
-        # Verify the API was called correctly
-        self.mock_evaluator_get.assert_called_once_with(f"/v1/evaluation/jobs/{result.job_id}")
-
-    def test_job_cancel(self):
-        # Mock Evaluator API response
-        mock_evaluator_response = {"id": "job-123", "status": "cancelled"}
-        self.mock_evaluator_post.return_value = mock_evaluator_response
-
-        # Cancel the Evaluation job
-        self.run_async(self.eval_impl.job_cancel(benchmark_id=MOCK_BENCHMARK_ID, job_id="job-123"))
-
-        # Verify the API was called correctly
-        self.mock_evaluator_post.assert_called_once_with("/v1/evaluation/jobs/job-123/cancel", {})
-
-    def test_job_result(self):
-        # Mock Evaluator API responses
-        mock_job_status_response = {"id": "job-123", "status": "completed"}
-        mock_job_results_response = {
-            "id": "job-123",
-            "status": "completed",
-            "results": {MOCK_BENCHMARK_ID: {"score": 0.85, "details": {"accuracy": 0.85, "f1": 0.84}}},
-        }
-        self.mock_evaluator_get.side_effect = [
-            mock_job_status_response,  # First call to retrieve job
-            mock_job_results_response,  # Second call to retrieve job results
-        ]
-
-        # Get the Evaluation job results
-        result = self.run_async(self.eval_impl.job_result(benchmark_id=MOCK_BENCHMARK_ID, job_id="job-123"))
-
-        # Verify the result
-        assert isinstance(result, EvaluateResponse)
-        assert MOCK_BENCHMARK_ID in result.scores
-        assert result.scores[MOCK_BENCHMARK_ID].aggregated_results["results"][MOCK_BENCHMARK_ID]["score"] == 0.85
-
-        # Verify the API was called correctly
-        assert self.mock_evaluator_get.call_count == 2
-        self.mock_evaluator_get.assert_any_call("/v1/evaluation/jobs/job-123")
-        self.mock_evaluator_get.assert_any_call("/v1/evaluation/jobs/job-123/results")
+@pytest.fixture
+def nvidia_eval_setup():
+    """Set up the NVIDIA eval implementation with mocked dependencies."""
+    os.environ["NVIDIA_EVALUATOR_URL"] = "http://nemo.test"
+
+    # Create mock APIs
+    datasetio_api = MagicMock()
+    datasets_api = MagicMock()
+    scoring_api = MagicMock()
+    inference_api = MagicMock()
+    agents_api = MagicMock()
+
+    config = NVIDIAEvalConfig(
+        evaluator_url=os.environ["NVIDIA_EVALUATOR_URL"],
+    )
+
+    eval_impl = NVIDIAEvalImpl(
+        config=config,
+        datasetio_api=datasetio_api,
+        datasets_api=datasets_api,
+        scoring_api=scoring_api,
+        inference_api=inference_api,
+        agents_api=agents_api,
+    )
+
+    # Mock the HTTP request methods
+    with (
+        patch("llama_stack.providers.remote.eval.nvidia.eval.NVIDIAEvalImpl._evaluator_get") as mock_evaluator_get,
+        patch("llama_stack.providers.remote.eval.nvidia.eval.NVIDIAEvalImpl._evaluator_post") as mock_evaluator_post,
+    ):
+        yield {
+            "eval_impl": eval_impl,
+            "mock_evaluator_get": mock_evaluator_get,
+            "mock_evaluator_post": mock_evaluator_post,
+            "datasetio_api": datasetio_api,
+            "datasets_api": datasets_api,
+            "scoring_api": scoring_api,
+            "inference_api": inference_api,
+            "agents_api": agents_api,
+        }
+
+
+def _assert_request_body(mock_evaluator_post, expected_json):
+    """Helper method to verify request body in Evaluator POST request is correct"""
+    call_args = mock_evaluator_post.call_args
+    actual_json = call_args[0][1]
+
+    # Check that all expected keys contain the expected values in the actual JSON
+    for key, value in expected_json.items():
+        assert key in actual_json, f"Key '{key}' missing in actual JSON"
+
+        if isinstance(value, dict):
+            for nested_key, nested_value in value.items():
+                assert nested_key in actual_json[key], f"Nested key '{nested_key}' missing in actual JSON['{key}']"
+                assert actual_json[key][nested_key] == nested_value, f"Value mismatch for '{key}.{nested_key}'"
+        else:
+            assert actual_json[key] == value, f"Value mismatch for '{key}'"
+
+
+async def test_register_benchmark(nvidia_eval_setup):
+    eval_impl = nvidia_eval_setup["eval_impl"]
+    mock_evaluator_post = nvidia_eval_setup["mock_evaluator_post"]
+
+    eval_config = {
+        "type": "custom",
+        "params": {"parallelism": 8},
+        "tasks": {
+            "qa": {
+                "type": "completion",
+                "params": {"template": {"prompt": "{{prompt}}", "max_tokens": 200}},
+                "dataset": {"files_url": f"hf://datasets/{MOCK_DATASET_ID}/testing/testing.jsonl"},
+                "metrics": {"bleu": {"type": "bleu", "params": {"references": ["{{ideal_response}}"]}}},
+            }
+        },
+    }
+
+    benchmark = Benchmark(
+        provider_id="nvidia",
+        type=ResourceType.benchmark,
+        identifier=MOCK_BENCHMARK_ID,
+        dataset_id=MOCK_DATASET_ID,
+        scoring_functions=["basic::equality"],
+        metadata=eval_config,
+    )
+
+    # Mock Evaluator API response
+    mock_evaluator_response = {"id": MOCK_BENCHMARK_ID, "status": "created"}
+    mock_evaluator_post.return_value = mock_evaluator_response
+
+    # Register the benchmark
+    await eval_impl.register_benchmark(benchmark)
+
+    # Verify the Evaluator API was called correctly
+    mock_evaluator_post.assert_called_once()
+    _assert_request_body(
+        mock_evaluator_post, {"namespace": benchmark.provider_id, "name": benchmark.identifier, **eval_config}
+    )
+
+
+async def test_run_eval(nvidia_eval_setup):
+    eval_impl = nvidia_eval_setup["eval_impl"]
+    mock_evaluator_post = nvidia_eval_setup["mock_evaluator_post"]
+
+    benchmark_config = BenchmarkConfig(
+        eval_candidate=ModelCandidate(
+            type="model",
+            model=CoreModelId.llama3_1_8b_instruct.value,
+            sampling_params=SamplingParams(max_tokens=100, strategy=TopPSamplingStrategy(temperature=0.7)),
+        )
+    )
+
+    # Mock Evaluator API response
+    mock_evaluator_response = {"id": "job-123", "status": "created"}
+    mock_evaluator_post.return_value = mock_evaluator_response
+
+    # Run the Evaluation job
+    result = await eval_impl.run_eval(benchmark_id=MOCK_BENCHMARK_ID, benchmark_config=benchmark_config)
+
+    # Verify the Evaluator API was called correctly
+    mock_evaluator_post.assert_called_once()
+    _assert_request_body(
+        mock_evaluator_post,
+        {
+            "config": f"nvidia/{MOCK_BENCHMARK_ID}",
+            "target": {"type": "model", "model": "Llama3.1-8B-Instruct"},
+        },
+    )
+
+    # Verify the result
+    assert isinstance(result, Job)
+    assert result.job_id == "job-123"
+    assert result.status == JobStatus.in_progress
+
+
+async def test_job_status(nvidia_eval_setup):
+    eval_impl = nvidia_eval_setup["eval_impl"]
+    mock_evaluator_get = nvidia_eval_setup["mock_evaluator_get"]
+
+    # Mock Evaluator API response
+    mock_evaluator_response = {"id": "job-123", "status": "completed"}
+    mock_evaluator_get.return_value = mock_evaluator_response
+
+    # Get the Evaluation job
+    result = await eval_impl.job_status(benchmark_id=MOCK_BENCHMARK_ID, job_id="job-123")
+
+    # Verify the result
+    assert isinstance(result, Job)
+    assert result.job_id == "job-123"
+    assert result.status == JobStatus.completed
+
+    # Verify the API was called correctly
+    mock_evaluator_get.assert_called_once_with(f"/v1/evaluation/jobs/{result.job_id}")
+
+
+async def test_job_cancel(nvidia_eval_setup):
+    eval_impl = nvidia_eval_setup["eval_impl"]
+    mock_evaluator_post = nvidia_eval_setup["mock_evaluator_post"]
+
+    # Mock Evaluator API response
+    mock_evaluator_response = {"id": "job-123", "status": "cancelled"}
+    mock_evaluator_post.return_value = mock_evaluator_response
+
+    # Cancel the Evaluation job
+    await eval_impl.job_cancel(benchmark_id=MOCK_BENCHMARK_ID, job_id="job-123")
+
+    # Verify the API was called correctly
+    mock_evaluator_post.assert_called_once_with("/v1/evaluation/jobs/job-123/cancel", {})
+
+
+async def test_job_result(nvidia_eval_setup):
+    eval_impl = nvidia_eval_setup["eval_impl"]
+    mock_evaluator_get = nvidia_eval_setup["mock_evaluator_get"]
+
+    # Mock Evaluator API responses
+    mock_job_status_response = {"id": "job-123", "status": "completed"}
+    mock_job_results_response = {
+        "id": "job-123",
+        "status": "completed",
+        "results": {MOCK_BENCHMARK_ID: {"score": 0.85, "details": {"accuracy": 0.85, "f1": 0.84}}},
+    }
+    mock_evaluator_get.side_effect = [
+        mock_job_status_response,  # First call to retrieve job
+        mock_job_results_response,  # Second call to retrieve job results
+    ]
+
+    # Get the Evaluation job results
+    result = await eval_impl.job_result(benchmark_id=MOCK_BENCHMARK_ID, job_id="job-123")
+
+    # Verify the result
+    assert isinstance(result, EvaluateResponse)
+    assert MOCK_BENCHMARK_ID in result.scores
+    assert result.scores[MOCK_BENCHMARK_ID].aggregated_results["results"][MOCK_BENCHMARK_ID]["score"] == 0.85
+
+    # Verify the API was called correctly
+    assert mock_evaluator_get.call_count == 2
+    mock_evaluator_get.assert_any_call("/v1/evaluation/jobs/job-123")
+    mock_evaluator_get.assert_any_call("/v1/evaluation/jobs/job-123/results")
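
The hunk above replaces a unittest.TestCase plus a run_async helper with a module-level yield fixture and plain async test functions. A minimal, self-contained sketch of that shift, assuming an async-capable pytest setup is configured elsewhere in the repository (for example pytest-asyncio in auto mode, which this diff does not show):

import asyncio


async def _answer() -> int:
    return 42


def test_old_style():
    # Old pattern: a synchronous test drives the coroutine through an
    # event-loop helper, as the removed run_async fixture did.
    assert asyncio.run(_answer()) == 42


async def test_new_style():
    # New pattern: the test itself is a coroutine; the async pytest plugin
    # supplies the event loop, so no helper is needed.
    assert await _answer() == 42
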
@@ -455,8 +455,8 @@ class TestOpenAIMixinAllowedModels:
     """Test cases for allowed_models filtering functionality"""
 
     async def test_list_models_with_allowed_models_filter(self, mixin, mock_client_with_models, mock_client_context):
-        """Test that list_models filters models based on allowed_models set"""
-        mixin.allowed_models = {"some-mock-model-id", "another-mock-model-id"}
+        """Test that list_models filters models based on allowed_models"""
+        mixin.config.allowed_models = ["some-mock-model-id", "another-mock-model-id"]
 
         with mock_client_context(mixin, mock_client_with_models):
             result = await mixin.list_models()

@@ -470,8 +470,18 @@ class TestOpenAIMixinAllowedModels:
         assert "final-mock-model-id" not in model_ids
 
     async def test_list_models_with_empty_allowed_models(self, mixin, mock_client_with_models, mock_client_context):
-        """Test that empty allowed_models set allows all models"""
-        assert len(mixin.allowed_models) == 0
+        """Test that empty allowed_models allows no models"""
+        mixin.config.allowed_models = []
 
         with mock_client_context(mixin, mock_client_with_models):
             result = await mixin.list_models()
+
+        assert result is not None
+        assert len(result) == 0  # No models should be included
+
+    async def test_list_models_with_omitted_allowed_models(self, mixin, mock_client_with_models, mock_client_context):
+        """Test that omitted allowed_models allows all models"""
+        assert mixin.config.allowed_models is None
+
+        with mock_client_context(mixin, mock_client_with_models):
+            result = await mixin.list_models()

@@ -488,7 +498,7 @@ class TestOpenAIMixinAllowedModels:
         self, mixin, mock_client_with_models, mock_client_context
     ):
         """Test that check_model_availability respects allowed_models"""
-        mixin.allowed_models = {"final-mock-model-id"}
+        mixin.config.allowed_models = ["final-mock-model-id"]
 
         with mock_client_context(mixin, mock_client_with_models):
             assert await mixin.check_model_availability("final-mock-model-id")

@@ -536,7 +546,7 @@ class TestOpenAIMixinModelRegistration:
 
     async def test_register_model_with_allowed_models_filter(self, mixin, mock_client_with_models, mock_client_context):
         """Test model registration with allowed_models filtering"""
-        mixin.allowed_models = {"some-mock-model-id"}
+        mixin.config.allowed_models = ["some-mock-model-id"]
 
         # Test with allowed model
         allowed_model = Model(

@@ -690,7 +700,7 @@ class TestOpenAIMixinCustomListProviderModelIds:
         mixin = CustomListProviderModelIdsImplementation(
            config=config, custom_model_ids=["model-1", "model-2", "model-3"]
         )
-        mixin.allowed_models = ["model-1"]
+        mixin.config.allowed_models = ["model-1"]
 
         result = await mixin.list_models()
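
Taken together, these hunks move the allowed_models setting onto the provider config and pin down its semantics: omitting the field exposes every provider model, an empty list exposes none, and a non-empty list acts as an allow-list. A hedged sketch of that rule (the helper below is illustrative only, not the actual OpenAIMixin code):

def is_model_allowed(model_id: str, allowed_models: list[str] | None) -> bool:
    if allowed_models is None:
        # Field omitted: every model reported by the provider is exposed.
        return True
    # An empty list exposes nothing; otherwise only the listed ids pass.
    return model_id in allowed_models


assert is_model_allowed("some-mock-model-id", None)
assert not is_model_allowed("some-mock-model-id", [])
assert is_model_allowed("final-mock-model-id", ["final-mock-model-id"])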