Fix agent executor

This commit is contained in:
Ashwin Bharambe 2024-12-16 11:21:51 -08:00
parent 59ce047aea
commit e0731ba353
4 changed files with 62 additions and 16 deletions

View file

@ -9,7 +9,7 @@ import tempfile
import pytest
import pytest_asyncio
from llama_stack.apis.models import ModelInput
from llama_stack.apis.models import ModelInput, ModelType
from llama_stack.distribution.datatypes import Api, Provider
from llama_stack.providers.inline.agents.meta_reference import (
@ -67,22 +67,42 @@ async def agents_stack(request, inference_model, safety_shield):
for key in ["inference", "safety", "memory", "agents"]:
fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}")
providers[key] = fixture.providers
if key == "inference":
providers[key].append(
Provider(
provider_id="agents_memory_provider",
provider_type="inline::sentence-transformers",
config={},
)
)
if fixture.provider_data:
provider_data.update(fixture.provider_data)
inference_models = (
inference_model if isinstance(inference_model, list) else [inference_model]
)
models = [
ModelInput(
model_id=model,
model_type=ModelType.llm,
provider_id=providers["inference"][0].provider_id,
)
for model in inference_models
]
models.append(
ModelInput(
model_id="all-MiniLM-L6-v2",
model_type=ModelType.embedding,
provider_id="agents_memory_provider",
metadata={"embedding_dimension": 384},
)
)
test_stack = await construct_stack_for_test(
[Api.agents, Api.inference, Api.safety, Api.memory],
providers,
provider_data,
models=[
ModelInput(
model_id=model,
)
for model in inference_models
],
models=models,
shields=[safety_shield] if safety_shield else [],
)
return test_stack