forked from phoenix-oss/llama-stack-mirror
[bugfix] fix meta-reference agents w/ safety multiple model loading pytest (#694)
# What does this PR do?

- Fix the broken pytest for meta-reference's agents.
- The safety model needs to be registered under a different provider id from the inference model in order to be recognized.

## Test Plan

```
torchrun $CONDA_PREFIX/bin/pytest -v -s llama_stack/providers/tests/agents/test_agents.py -m "meta_reference" --safety-shield meta-llama/Llama-Guard-3-1B --inference-model meta-llama/Llama-3.1-8B-Instruct
```

**Before**

<img width="845" alt="image" src="https://github.com/user-attachments/assets/83818fe1-2179-4e9c-a753-bf1472a2f01d" />

**After**

<img width="851" alt="image" src="https://github.com/user-attachments/assets/1cf8124b-14e2-47bf-80fd-ef8b4b3f6fd9" />

**Other tests are not broken:**

```
pytest -v -s llama_stack/providers/tests/agents/test_agents.py -m "together" --safety-shield meta-llama/Llama-Guard-3-8B --inference-model meta-llama/Llama-3.1-405B-Instruct-FP8
```

## Sources

Please link relevant resources if necessary.

## Before submitting

- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [ ] Ran pre-commit to handle lint / formatting issues.
- [ ] Read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md), Pull Request section?
- [ ] Updated relevant documentation.
- [ ] Wrote necessary unit or integration tests.
This commit is contained in:
parent 8ba29b19f2
commit 7c1e3daa75

1 changed file with 21 additions and 7 deletions
```diff
@@ -81,14 +81,28 @@ async def agents_stack(request, inference_model, safety_shield):
     inference_models = (
         inference_model if isinstance(inference_model, list) else [inference_model]
     )
-    models = [
-        ModelInput(
-            model_id=model,
-            model_type=ModelType.llm,
-            provider_id=providers["inference"][0].provider_id,
-        )
-        for model in inference_models
-    ]
+
+    # NOTE: meta-reference provider needs 1 provider per model, lookup provider_id from provider config
+    model_to_provider_id = {}
+    for provider in providers["inference"]:
+        if "model" in provider.config:
+            model_to_provider_id[provider.config["model"]] = provider.provider_id
+
+    models = []
+    for model in inference_models:
+        if model in model_to_provider_id:
+            provider_id = model_to_provider_id[model]
+        else:
+            provider_id = providers["inference"][0].provider_id
+
+        models.append(
+            ModelInput(
+                model_id=model,
+                model_type=ModelType.llm,
+                provider_id=provider_id,
+            )
+        )
+
     models.append(
         ModelInput(
             model_id="all-MiniLM-L6-v2",
```
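The per-model provider lookup above can be sketched in isolation. `Provider` below is a simplified stand-in for the actual llama-stack provider object (the real one carries more fields), and `resolve_provider_ids` is a hypothetical helper name; the logic mirrors the diff: map each model to the provider whose config declares it, falling back to the first inference provider.

```python
# Simplified stand-in for a llama-stack provider entry (hypothetical;
# the real Provider object has additional fields).
class Provider:
    def __init__(self, provider_id, config):
        self.provider_id = provider_id
        self.config = config

def resolve_provider_ids(inference_models, providers):
    """Map each model to the provider whose config declares it,
    falling back to the first inference provider."""
    # Build model -> provider_id from each provider's config
    model_to_provider_id = {
        p.config["model"]: p.provider_id
        for p in providers
        if "model" in p.config
    }
    default_id = providers[0].provider_id
    return {m: model_to_provider_id.get(m, default_id) for m in inference_models}

# Example: one meta-reference provider per model, as the NOTE in the diff requires
providers = [
    Provider("meta-reference-0", {"model": "meta-llama/Llama-3.1-8B-Instruct"}),
    Provider("meta-reference-1", {"model": "meta-llama/Llama-Guard-3-1B"}),
]
ids = resolve_provider_ids(
    ["meta-llama/Llama-3.1-8B-Instruct", "meta-llama/Llama-Guard-3-1B"],
    providers,
)
```

This is why the safety shield ends up registered under a different provider id than the inference model: each meta-reference provider serves exactly one model, so the lookup routes each model to its own provider.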