Merge branch 'main' of https://github.com/santiagxf/llama-stack into santiagxf/azure-ai-inference

Facundo Santiago 2024-11-11 21:15:27 +00:00
commit 8bbc15830e
139 changed files with 6797 additions and 1542 deletions

View file

@@ -3,11 +3,12 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-from .bedrock import BedrockInferenceAdapter
 from .config import BedrockConfig
 async def get_adapter_impl(config: BedrockConfig, _deps):
+    from .bedrock import BedrockInferenceAdapter
     assert isinstance(config, BedrockConfig), f"Unexpected config type: {type(config)}"
     impl = BedrockInferenceAdapter(config)
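
The only functional change in this file is that the `BedrockInferenceAdapter` import moves inside `get_adapter_impl`, i.e. it becomes a deferred import, presumably so the boto3-backed module is only loaded when the adapter is actually instantiated. A minimal, self-contained sketch of the pattern (the `sqlite3` stand-in and the factory body are illustrative, not from the diff):

```python
import asyncio


async def get_adapter_impl(db_path: str):
    # Deferred import: sqlite3 stands in for a heavy provider dependency.
    # It is only imported when the factory coroutine runs, not when this
    # module is imported.
    import sqlite3

    return sqlite3.connect(db_path)


print(asyncio.run(get_adapter_impl(":memory:")))
```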

View file

@@ -84,7 +84,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
         contents = bedrock_message["content"]
         tool_calls = []
-        text_content = []
+        text_content = ""
         for content in contents:
             if "toolUse" in content:
                 tool_use = content["toolUse"]
@@ -98,7 +98,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
                 )
             )
         elif "text" in content:
-            text_content.append(content["text"])
+            text_content += content["text"]
     return CompletionMessage(
         role=role,
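
The `text_content` change above switches from collecting text blocks in a list to concatenating them into a single string, presumably because the resulting `CompletionMessage.content` is expected to be one string rather than a list. A self-contained sketch of the accumulation loop (the sample content blocks are made up):

```python
# Made-up Bedrock-style content blocks: a mix of text and tool-use entries.
contents = [
    {"text": "The answer is "},
    {"toolUse": {"name": "calculator", "input": {"expr": "6*7"}}},
    {"text": "42."},
]

text_content = ""
for content in contents:
    if "text" in content:
        text_content += content["text"]  # concatenate, as in the new code

print(text_content)  # -> The answer is 42.
```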

View file

@@ -15,7 +15,7 @@ from llama_models.llama3.api.tokenizer import Tokenizer
 from ollama import AsyncClient
 from llama_stack.apis.inference import *  # noqa: F403
-from llama_stack.providers.datatypes import ModelsProtocolPrivate
+from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
 from llama_stack.providers.utils.inference.openai_compat import (
     get_sampling_options,
@@ -65,10 +65,11 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
     async def shutdown(self) -> None:
         pass
-    async def register_model(self, model: ModelDef) -> None:
-        raise ValueError("Dynamic model registration is not supported")
+    async def register_model(self, model: Model) -> None:
+        if model.identifier not in OLLAMA_SUPPORTED_MODELS:
+            raise ValueError(f"Model {model.identifier} is not supported by Ollama")
-    async def list_models(self) -> List[ModelDef]:
+    async def list_models(self) -> List[Model]:
         ollama_to_llama = {v: k for k, v in OLLAMA_SUPPORTED_MODELS.items()}
         ret = []
@@ -79,10 +80,10 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
                 continue
             llama_model = ollama_to_llama[r["model"]]
             print(f"Found model {llama_model} in Ollama")
             ret.append(
-                ModelDef(
+                Model(
                     identifier=llama_model,
                     llama_model=llama_model,
                     metadata={
                         "ollama_model": r["model"],
                     },
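
`register_model` now validates against `OLLAMA_SUPPORTED_MODELS` instead of rejecting registration outright, and `list_models` inverts that table to map running Ollama tags back to Llama Stack identifiers. A self-contained sketch of that lookup; the table entries and the `running` payload are illustrative, not taken from the diff:

```python
# Illustrative mapping of Llama Stack identifiers to Ollama model tags.
OLLAMA_SUPPORTED_MODELS = {
    "Llama3.1-8B-Instruct": "llama3.1:8b-instruct-fp16",
    "Llama3.2-3B-Instruct": "llama3.2:3b-instruct-fp16",
}

# Invert the table so an Ollama tag can be resolved back to an identifier.
ollama_to_llama = {v: k for k, v in OLLAMA_SUPPORTED_MODELS.items()}

# Stand-in for the models Ollama reports as currently available.
running = [{"model": "llama3.1:8b-instruct-fp16"}, {"model": "mistral:7b"}]
for r in running:
    if r["model"] not in ollama_to_llama:
        continue  # Ollama is serving a model unknown to Llama Stack
    print(f"Found model {ollama_to_llama[r['model']]} in Ollama")
```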

View file

@@ -14,7 +14,7 @@ class SampleInferenceImpl(Inference):
     def __init__(self, config: SampleConfig):
         self.config = config
-    async def register_model(self, model: ModelDef) -> None:
+    async def register_model(self, model: Model) -> None:
         # these are the model names the Llama Stack will use to route requests to this provider
         # perform validation here if necessary
         pass

View file

@@ -16,7 +16,7 @@ from llama_models.sku_list import all_registered_models
 from llama_stack.apis.inference import *  # noqa: F403
 from llama_stack.apis.models import *  # noqa: F403
-from llama_stack.providers.datatypes import ModelDef, ModelsProtocolPrivate
+from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
 from llama_stack.providers.utils.inference.openai_compat import (
     get_sampling_options,
@@ -50,14 +50,14 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
             if model.huggingface_repo
         }
-    async def register_model(self, model: ModelDef) -> None:
-        raise ValueError("Model registration is not supported for HuggingFace models")
+    async def register_model(self, model: Model) -> None:
+        pass
-    async def list_models(self) -> List[ModelDef]:
+    async def list_models(self) -> List[Model]:
         repo = self.model_id
         identifier = self.huggingface_repo_to_llama_model_id[repo]
         return [
-            ModelDef(
+            Model(
                 identifier=identifier,
                 llama_model=identifier,
                 metadata={

View file

@@ -13,7 +13,7 @@ from llama_models.sku_list import all_registered_models, resolve_model
 from openai import OpenAI
 from llama_stack.apis.inference import *  # noqa: F403
-from llama_stack.providers.datatypes import ModelsProtocolPrivate
+from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
 from llama_stack.providers.utils.inference.openai_compat import (
     get_sampling_options,
@@ -44,13 +44,13 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
     async def initialize(self) -> None:
         self.client = OpenAI(base_url=self.config.url, api_key=self.config.api_token)
-    async def register_model(self, model: ModelDef) -> None:
+    async def register_model(self, model: Model) -> None:
         raise ValueError("Model registration is not supported for vLLM models")
     async def shutdown(self) -> None:
         pass
-    async def list_models(self) -> List[ModelDef]:
+    async def list_models(self) -> List[Model]:
         models = []
         for model in self.client.models.list():
             repo = model.id
@@ -60,7 +60,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
             identifier = self.huggingface_repo_to_llama_model_id[repo]
             models.append(
-                ModelDef(
+                Model(
                     identifier=identifier,
                     llama_model=identifier,
                 )
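
The vLLM adapter discovers models through the OpenAI-compatible `/v1/models` endpoint that a vLLM server exposes. A hedged sketch of that discovery call; the base URL and API key are placeholders, and a locally running vLLM server is assumed:

```python
from openai import OpenAI

# Placeholder endpoint and token; vLLM's OpenAI-compatible server ignores the
# key unless one was configured at launch.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="not-needed")

for model in client.models.list():
    # Each entry's id is typically the HuggingFace repo the server was started with.
    print(model.id)
```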

View file

@@ -21,7 +21,7 @@ logger = logging.getLogger(__name__)
 BEDROCK_SUPPORTED_SHIELDS = [
-    ShieldType.generic_content_shield.value,
+    ShieldType.generic_content_shield,
 ]
@@ -40,32 +40,25 @@ class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate):
     async def shutdown(self) -> None:
         pass
-    async def register_shield(self, shield: ShieldDef) -> None:
-        raise ValueError("Registering dynamic shields is not supported")
-    async def list_shields(self) -> List[ShieldDef]:
-        response = self.bedrock_client.list_guardrails()
-        shields = []
-        for guardrail in response["guardrails"]:
-            # populate the shield def with the guardrail id and version
-            shield_def = ShieldDef(
-                identifier=guardrail["id"],
-                shield_type=ShieldType.generic_content_shield.value,
-                params={
-                    "guardrailIdentifier": guardrail["id"],
-                    "guardrailVersion": guardrail["version"],
-                },
+    async def register_shield(self, shield: Shield) -> None:
+        response = self.bedrock_client.list_guardrails(
+            guardrailIdentifier=shield.provider_resource_id,
+        )
+        if (
+            not response["guardrails"]
+            or len(response["guardrails"]) == 0
+            or response["guardrails"][0]["version"] != shield.params["guardrailVersion"]
+        ):
+            raise ValueError(
+                f"Shield {shield.provider_resource_id} with version {shield.params['guardrailVersion']} not found in Bedrock"
+            )
-            self.registered_shields.append(shield_def)
-            shields.append(shield_def)
-        return shields
     async def run_shield(
-        self, identifier: str, messages: List[Message], params: Dict[str, Any] = None
+        self, shield_id: str, messages: List[Message], params: Dict[str, Any] = None
     ) -> RunShieldResponse:
-        shield_def = await self.shield_store.get_shield(identifier)
-        if not shield_def:
-            raise ValueError(f"Unknown shield {identifier}")
+        shield = await self.shield_store.get_shield(shield_id)
+        if not shield:
+            raise ValueError(f"Shield {shield_id} not found")
         """This is the implementation for the bedrock guardrails. The input to the guardrails is to be of this format
         ```content = [
@@ -81,7 +74,7 @@ class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate):
         They contain content, role . For now we will extract the content and default the "qualifiers": ["query"]
         """
-        shield_params = shield_def.params
+        shield_params = shield.params
         logger.debug(f"run_shield::{shield_params}::messages={messages}")
         # - convert the messages into format Bedrock expects
@@ -93,7 +86,7 @@ class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate):
         )
         response = self.bedrock_runtime_client.apply_guardrail(
-            guardrailIdentifier=shield_params["guardrailIdentifier"],
+            guardrailIdentifier=shield.provider_resource_id,
             guardrailVersion=shield_params["guardrailVersion"],
             source="OUTPUT",  # or 'INPUT' depending on your use case
             content=content_messages,
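
The docstring above describes the payload shape Bedrock guardrails expect, with the `qualifiers` defaulted to `["query"]`. A sketch of the underlying boto3 call; the guardrail id, version, and message text are made-up values, and AWS credentials are assumed to be configured:

```python
import boto3

bedrock_runtime = boto3.client("bedrock-runtime")

content_messages = [
    {"text": {"text": "Is this product a safe investment?", "qualifiers": ["query"]}}
]

response = bedrock_runtime.apply_guardrail(
    guardrailIdentifier="gr-example123",  # made-up guardrail id
    guardrailVersion="1",                 # made-up version
    source="OUTPUT",
    content=content_messages,
)
print(response["action"])  # "NONE" or "GUARDRAIL_INTERVENED"
```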

View file

@@ -14,7 +14,7 @@ class SampleSafetyImpl(Safety):
     def __init__(self, config: SampleConfig):
         self.config = config
-    async def register_shield(self, shield: ShieldDef) -> None:
+    async def register_shield(self, shield: Shield) -> None:
         # these are the safety shields the Llama Stack will use to route requests to this provider
         # perform validation here if necessary
         pass
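
As a closing note on the `ShieldDef` to `Shield` rename: the Bedrock adapter's `register_shield` only reads `provider_resource_id` and `params["guardrailVersion"]` from the shield it is given. A hypothetical stand-in illustrating just those fields (the real `Shield` resource in Llama Stack carries more fields than shown here):

```python
from dataclasses import dataclass, field
from typing import Any, Dict


@dataclass
class ShieldStandIn:
    # Only the attributes the Bedrock register_shield path reads; this is an
    # illustration, not the real Shield class.
    identifier: str
    provider_resource_id: str
    params: Dict[str, Any] = field(default_factory=dict)


shield = ShieldStandIn(
    identifier="bedrock-content-shield",
    provider_resource_id="gr-example123",  # made-up Bedrock guardrail id
    params={"guardrailVersion": "1"},      # version the adapter validates
)
print(shield.provider_resource_id, shield.params["guardrailVersion"])
```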