From 98c09323a9fe16229f006073cc86cd5299f46e1a Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Thu, 7 Nov 2024 15:24:45 -0800 Subject: [PATCH] bedrock test for inference fixes --- llama_stack/providers/remote/inference/bedrock/bedrock.py | 4 ++-- llama_stack/providers/tests/inference/fixtures.py | 8 +------- llama_stack/providers/tests/safety/fixtures.py | 2 +- 3 files changed, 4 insertions(+), 10 deletions(-) diff --git a/llama_stack/providers/remote/inference/bedrock/bedrock.py b/llama_stack/providers/remote/inference/bedrock/bedrock.py index f569e0093..d9f82c611 100644 --- a/llama_stack/providers/remote/inference/bedrock/bedrock.py +++ b/llama_stack/providers/remote/inference/bedrock/bedrock.py @@ -84,7 +84,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference): contents = bedrock_message["content"] tool_calls = [] - text_content = [] + text_content = "" for content in contents: if "toolUse" in content: tool_use = content["toolUse"] @@ -98,7 +98,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference): ) ) elif "text" in content: - text_content.append(content["text"]) + text_content += content["text"] return CompletionMessage( role=role, diff --git a/llama_stack/providers/tests/inference/fixtures.py b/llama_stack/providers/tests/inference/fixtures.py index 04ad46fae..7363fa961 100644 --- a/llama_stack/providers/tests/inference/fixtures.py +++ b/llama_stack/providers/tests/inference/fixtures.py @@ -10,10 +10,10 @@ import pytest import pytest_asyncio from llama_stack.distribution.datatypes import Api, Provider -from llama_stack.providers.adapters.inference.bedrock import BedrockConfig from llama_stack.providers.inline.inference.meta_reference import ( MetaReferenceInferenceConfig, ) +from llama_stack.providers.remote.inference.bedrock import BedrockConfig from llama_stack.providers.remote.inference.fireworks import FireworksImplConfig from llama_stack.providers.remote.inference.ollama import OllamaImplConfig @@ -142,19 +142,13 @@ def inference_bedrock() -> ProviderFixture: INFERENCE_FIXTURES = [ - "meta_reference", - "ollama", - "fireworks", - "together", - "vllm_remote", "remote", "bedrock", -, ] diff --git a/llama_stack/providers/tests/safety/fixtures.py b/llama_stack/providers/tests/safety/fixtures.py index 41a6c4624..3a374815f 100644 --- a/llama_stack/providers/tests/safety/fixtures.py +++ b/llama_stack/providers/tests/safety/fixtures.py @@ -10,11 +10,11 @@ import pytest_asyncio from llama_stack.apis.shields import Shield, ShieldType from llama_stack.distribution.datatypes import Api, Provider -from llama_stack.providers.adapters.safety.bedrock import BedrockSafetyConfig from llama_stack.providers.inline.safety.meta_reference import ( LlamaGuardShieldConfig, SafetyConfig, ) +from llama_stack.providers.remote.safety.bedrock import BedrockSafetyConfig from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2 from ..conftest import ProviderFixture, remote_stack_fixture