bedrock test for inference fixes

This commit is contained in:
Dinesh Yeduguru 2024-11-07 15:24:45 -08:00
parent e0f227f23c
commit 98c09323a9
3 changed files with 4 additions and 10 deletions

View file

@ -84,7 +84,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
contents = bedrock_message["content"] contents = bedrock_message["content"]
tool_calls = [] tool_calls = []
text_content = [] text_content = ""
for content in contents: for content in contents:
if "toolUse" in content: if "toolUse" in content:
tool_use = content["toolUse"] tool_use = content["toolUse"]
@ -98,7 +98,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
) )
) )
elif "text" in content: elif "text" in content:
text_content.append(content["text"]) text_content += content["text"]
return CompletionMessage( return CompletionMessage(
role=role, role=role,

View file

@ -10,10 +10,10 @@ import pytest
import pytest_asyncio import pytest_asyncio
from llama_stack.distribution.datatypes import Api, Provider from llama_stack.distribution.datatypes import Api, Provider
from llama_stack.providers.adapters.inference.bedrock import BedrockConfig
from llama_stack.providers.inline.inference.meta_reference import ( from llama_stack.providers.inline.inference.meta_reference import (
MetaReferenceInferenceConfig, MetaReferenceInferenceConfig,
) )
from llama_stack.providers.remote.inference.bedrock import BedrockConfig
from llama_stack.providers.remote.inference.fireworks import FireworksImplConfig from llama_stack.providers.remote.inference.fireworks import FireworksImplConfig
from llama_stack.providers.remote.inference.ollama import OllamaImplConfig from llama_stack.providers.remote.inference.ollama import OllamaImplConfig
@ -142,19 +142,13 @@ def inference_bedrock() -> ProviderFixture:
INFERENCE_FIXTURES = [ INFERENCE_FIXTURES = [
"meta_reference", "meta_reference",
"ollama", "ollama",
"fireworks", "fireworks",
"together", "together",
"vllm_remote", "vllm_remote",
"remote", "remote",
"bedrock", "bedrock",
,
] ]

View file

@ -10,11 +10,11 @@ import pytest_asyncio
from llama_stack.apis.shields import Shield, ShieldType from llama_stack.apis.shields import Shield, ShieldType
from llama_stack.distribution.datatypes import Api, Provider from llama_stack.distribution.datatypes import Api, Provider
from llama_stack.providers.adapters.safety.bedrock import BedrockSafetyConfig
from llama_stack.providers.inline.safety.meta_reference import ( from llama_stack.providers.inline.safety.meta_reference import (
LlamaGuardShieldConfig, LlamaGuardShieldConfig,
SafetyConfig, SafetyConfig,
) )
from llama_stack.providers.remote.safety.bedrock import BedrockSafetyConfig
from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2 from llama_stack.providers.tests.resolver import resolve_impls_for_test_v2
from ..conftest import ProviderFixture, remote_stack_fixture from ..conftest import ProviderFixture, remote_stack_fixture