mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 04:04:14 +00:00
Add tests
This commit is contained in:
parent
bab9d7aaea
commit
d7cbeb4b8c
6 changed files with 345 additions and 5 deletions
|
@ -119,6 +119,7 @@ def client_with_models(
|
|||
embedding_model_id,
|
||||
embedding_dimension,
|
||||
judge_model_id,
|
||||
rerank_model_id,
|
||||
):
|
||||
client = llama_stack_client
|
||||
|
||||
|
@ -151,6 +152,13 @@ def client_with_models(
|
|||
model_type="embedding",
|
||||
metadata={"embedding_dimension": embedding_dimension or 384},
|
||||
)
|
||||
if rerank_model_id and rerank_model_id not in model_ids:
|
||||
rerank_provider = providers[0]
|
||||
client.models.register(
|
||||
model_id=rerank_model_id,
|
||||
provider_id=rerank_provider.provider_id,
|
||||
model_type="rerank",
|
||||
)
|
||||
return client
|
||||
|
||||
|
||||
|
@ -166,7 +174,7 @@ def model_providers(llama_stack_client):
|
|||
|
||||
@pytest.fixture(autouse=True)
|
||||
def skip_if_no_model(request):
|
||||
model_fixtures = ["text_model_id", "vision_model_id", "embedding_model_id", "judge_model_id", "shield_id"]
|
||||
model_fixtures = ["text_model_id", "vision_model_id", "embedding_model_id", "judge_model_id", "shield_id", "rerank_model_id"]
|
||||
test_func = request.node.function
|
||||
|
||||
actual_params = inspect.signature(test_func).parameters.keys()
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue