Merge branch 'main' of https://github.com/santiagxf/llama-stack into santiagxf/azure-ai-inference

This commit is contained in:
Facundo Santiago 2024-11-11 21:15:27 +00:00
commit 8bbc15830e
139 changed files with 6797 additions and 1542 deletions

View file

@ -3,11 +3,12 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .bedrock import BedrockInferenceAdapter
from .config import BedrockConfig
async def get_adapter_impl(config: BedrockConfig, _deps):
from .bedrock import BedrockInferenceAdapter
assert isinstance(config, BedrockConfig), f"Unexpected config type: {type(config)}"
impl = BedrockInferenceAdapter(config)

View file

@ -84,7 +84,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
contents = bedrock_message["content"]
tool_calls = []
text_content = []
text_content = ""
for content in contents:
if "toolUse" in content:
tool_use = content["toolUse"]
@ -98,7 +98,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
)
)
elif "text" in content:
text_content.append(content["text"])
text_content += content["text"]
return CompletionMessage(
role=role,

View file

@ -15,7 +15,7 @@ from llama_models.llama3.api.tokenizer import Tokenizer
from ollama import AsyncClient
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
from llama_stack.providers.utils.inference.openai_compat import (
get_sampling_options,
@ -65,10 +65,11 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
async def shutdown(self) -> None:
pass
async def register_model(self, model: ModelDef) -> None:
raise ValueError("Dynamic model registration is not supported")
async def register_model(self, model: Model) -> None:
if model.identifier not in OLLAMA_SUPPORTED_MODELS:
raise ValueError(f"Model {model.identifier} is not supported by Ollama")
async def list_models(self) -> List[ModelDef]:
async def list_models(self) -> List[Model]:
ollama_to_llama = {v: k for k, v in OLLAMA_SUPPORTED_MODELS.items()}
ret = []
@ -79,10 +80,10 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
continue
llama_model = ollama_to_llama[r["model"]]
print(f"Found model {llama_model} in Ollama")
ret.append(
ModelDef(
Model(
identifier=llama_model,
llama_model=llama_model,
metadata={
"ollama_model": r["model"],
},

View file

@ -14,7 +14,7 @@ class SampleInferenceImpl(Inference):
def __init__(self, config: SampleConfig):
self.config = config
async def register_model(self, model: ModelDef) -> None:
async def register_model(self, model: Model) -> None:
# these are the model names the Llama Stack will use to route requests to this provider
# perform validation here if necessary
pass

View file

@ -16,7 +16,7 @@ from llama_models.sku_list import all_registered_models
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.apis.models import * # noqa: F403
from llama_stack.providers.datatypes import ModelDef, ModelsProtocolPrivate
from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
from llama_stack.providers.utils.inference.openai_compat import (
get_sampling_options,
@ -50,14 +50,14 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
if model.huggingface_repo
}
async def register_model(self, model: ModelDef) -> None:
raise ValueError("Model registration is not supported for HuggingFace models")
async def register_model(self, model: Model) -> None:
pass
async def list_models(self) -> List[ModelDef]:
async def list_models(self) -> List[Model]:
repo = self.model_id
identifier = self.huggingface_repo_to_llama_model_id[repo]
return [
ModelDef(
Model(
identifier=identifier,
llama_model=identifier,
metadata={

View file

@ -13,7 +13,7 @@ from llama_models.sku_list import all_registered_models, resolve_model
from openai import OpenAI
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
from llama_stack.providers.utils.inference.openai_compat import (
get_sampling_options,
@ -44,13 +44,13 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
async def initialize(self) -> None:
self.client = OpenAI(base_url=self.config.url, api_key=self.config.api_token)
async def register_model(self, model: ModelDef) -> None:
async def register_model(self, model: Model) -> None:
raise ValueError("Model registration is not supported for vLLM models")
async def shutdown(self) -> None:
pass
async def list_models(self) -> List[ModelDef]:
async def list_models(self) -> List[Model]:
models = []
for model in self.client.models.list():
repo = model.id
@ -60,7 +60,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
identifier = self.huggingface_repo_to_llama_model_id[repo]
models.append(
ModelDef(
Model(
identifier=identifier,
llama_model=identifier,
)