Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-21 07:52:25 +00:00)
Merge branch 'main' of https://github.com/santiagxf/llama-stack into santiagxf/azure-ai-inference
Commit 8bbc15830e
139 changed files with 6797 additions and 1542 deletions
@@ -3,11 +3,12 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from .bedrock import BedrockInferenceAdapter
 from .config import BedrockConfig


 async def get_adapter_impl(config: BedrockConfig, _deps):
+    from .bedrock import BedrockInferenceAdapter
+
     assert isinstance(config, BedrockConfig), f"Unexpected config type: {type(config)}"
     impl = BedrockInferenceAdapter(config)
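For orientation, the hunk above moves the BedrockInferenceAdapter import from module level into get_adapter_impl; the rest of the factory is cut off by the hunk boundary. Below is a minimal, self-contained sketch of that provider-factory pattern using stand-in classes. DummyConfig, DummyAdapter, and the initialize/return steps are assumptions for illustration, not the Bedrock adapter's actual code.

import asyncio
from dataclasses import dataclass


@dataclass
class DummyConfig:
    region: str = "us-east-1"  # hypothetical field


class DummyAdapter:
    def __init__(self, config: DummyConfig):
        self.config = config

    async def initialize(self) -> None:
        # a real adapter would create its SDK client here
        pass


async def get_adapter_impl(config: DummyConfig, _deps):
    # deferring the adapter import to call time is the apparent motivation for the change above
    assert isinstance(config, DummyConfig), f"Unexpected config type: {type(config)}"
    impl = DummyAdapter(config)
    await impl.initialize()  # assumption: provider factories typically initialize before returning
    return impl


print(asyncio.run(get_adapter_impl(DummyConfig(), _deps=None)))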
@@ -84,7 +84,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
         contents = bedrock_message["content"]

         tool_calls = []
-        text_content = []
+        text_content = ""
         for content in contents:
             if "toolUse" in content:
                 tool_use = content["toolUse"]
@@ -98,7 +98,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
                     )
                 )
             elif "text" in content:
-                text_content.append(content["text"])
+                text_content += content["text"]

         return CompletionMessage(
             role=role,
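The two Bedrock hunks above switch text accumulation from a list of fragments to plain string concatenation. A standalone illustration of the difference, with a made-up contents payload:

contents = [
    {"text": "The weather in "},
    {"toolUse": {"name": "lookup_city", "input": {}}},  # tool-use parts carry no text
    {"text": "Paris is sunny."},
]

# Old behavior: collect fragments in a list.
text_parts = []
for content in contents:
    if "text" in content:
        text_parts.append(content["text"])

# New behavior: build one string.
text_content = ""
for content in contents:
    if "text" in content:
        text_content += content["text"]

print(text_parts)    # ['The weather in ', 'Paris is sunny.']
print(text_content)  # The weather in Paris is sunny.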
@@ -15,7 +15,7 @@ from llama_models.llama3.api.tokenizer import Tokenizer
 from ollama import AsyncClient

 from llama_stack.apis.inference import *  # noqa: F403
-from llama_stack.providers.datatypes import ModelsProtocolPrivate
+from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate

 from llama_stack.providers.utils.inference.openai_compat import (
     get_sampling_options,
@@ -65,10 +65,11 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
     async def shutdown(self) -> None:
         pass

-    async def register_model(self, model: ModelDef) -> None:
-        raise ValueError("Dynamic model registration is not supported")
+    async def register_model(self, model: Model) -> None:
+        if model.identifier not in OLLAMA_SUPPORTED_MODELS:
+            raise ValueError(f"Model {model.identifier} is not supported by Ollama")

-    async def list_models(self) -> List[ModelDef]:
+    async def list_models(self) -> List[Model]:
         ollama_to_llama = {v: k for k, v in OLLAMA_SUPPORTED_MODELS.items()}

         ret = []
@@ -79,10 +80,10 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
                 continue

             llama_model = ollama_to_llama[r["model"]]
             print(f"Found model {llama_model} in Ollama")
             ret.append(
-                ModelDef(
+                Model(
                     identifier=llama_model,
                     llama_model=llama_model,
                     metadata={
                         "ollama_model": r["model"],
                     },
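Putting the two Ollama hunks together: register_model now validates against the adapter's supported-model mapping instead of refusing outright, and list_models inverts that mapping to describe locally available tags. A self-contained sketch with stand-in data, shown synchronously for brevity; the Model dataclass, the mapping entry, and the daemon listing are illustrative, not the adapter's real definitions.

from dataclasses import dataclass, field


@dataclass
class Model:  # simplified stand-in for llama_stack.providers.datatypes.Model
    identifier: str
    llama_model: str = ""
    metadata: dict = field(default_factory=dict)


OLLAMA_SUPPORTED_MODELS = {
    "Llama3.1-8B-Instruct": "llama3.1:8b-instruct-fp16",  # hypothetical entry
}


def register_model(model: Model) -> None:
    # mirrors the new check: unknown identifiers are rejected
    if model.identifier not in OLLAMA_SUPPORTED_MODELS:
        raise ValueError(f"Model {model.identifier} is not supported by Ollama")


def list_models(local_tags: list) -> list:
    ollama_to_llama = {v: k for k, v in OLLAMA_SUPPORTED_MODELS.items()}
    ret = []
    for r in local_tags:
        if r["model"] not in ollama_to_llama:
            continue  # skip tags with no known llama mapping
        llama_model = ollama_to_llama[r["model"]]
        ret.append(
            Model(
                identifier=llama_model,
                llama_model=llama_model,
                metadata={"ollama_model": r["model"]},
            )
        )
    return ret


register_model(Model(identifier="Llama3.1-8B-Instruct"))  # passes validation
print(list_models([{"model": "llama3.1:8b-instruct-fp16"}, {"model": "unrelated:tag"}]))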
@@ -14,7 +14,7 @@ class SampleInferenceImpl(Inference):
     def __init__(self, config: SampleConfig):
         self.config = config

-    async def register_model(self, model: ModelDef) -> None:
+    async def register_model(self, model: Model) -> None:
         # these are the model names the Llama Stack will use to route requests to this provider
         # perform validation here if necessary
         pass
@@ -16,7 +16,7 @@ from llama_models.sku_list import all_registered_models
 from llama_stack.apis.inference import *  # noqa: F403
 from llama_stack.apis.models import *  # noqa: F403

-from llama_stack.providers.datatypes import ModelDef, ModelsProtocolPrivate
+from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate

 from llama_stack.providers.utils.inference.openai_compat import (
     get_sampling_options,
@@ -50,14 +50,14 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
             if model.huggingface_repo
         }

-    async def register_model(self, model: ModelDef) -> None:
-        raise ValueError("Model registration is not supported for HuggingFace models")
+    async def register_model(self, model: Model) -> None:
+        pass

-    async def list_models(self) -> List[ModelDef]:
+    async def list_models(self) -> List[Model]:
         repo = self.model_id
         identifier = self.huggingface_repo_to_llama_model_id[repo]
         return [
-            ModelDef(
+            Model(
                 identifier=identifier,
                 llama_model=identifier,
                 metadata={
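The `if model.huggingface_repo` / `}` context at the top of this hunk is the tail of the adapter's repo-to-identifier mapping, which list_models then consults for the single model the endpoint serves. A standalone sketch of that mapping and lookup; the RegisteredModel stand-in and the sample entries are assumptions, not the real llama_models registry.

from dataclasses import dataclass
from typing import Optional


@dataclass
class RegisteredModel:  # stand-in for an entry from llama_models.sku_list
    descriptor: str
    huggingface_repo: Optional[str]


all_models = [
    RegisteredModel("Llama3.1-8B-Instruct", "meta-llama/Llama-3.1-8B-Instruct"),  # hypothetical
    RegisteredModel("Some-Other-Model", None),  # no HF repo, filtered out below
]

# Mirrors the filtered comprehension whose closing lines appear as hunk context.
huggingface_repo_to_llama_model_id = {
    m.huggingface_repo: m.descriptor for m in all_models if m.huggingface_repo
}

# list_models-style lookup for the one repo the serving endpoint reports.
model_id = "meta-llama/Llama-3.1-8B-Instruct"
identifier = huggingface_repo_to_llama_model_id[model_id]
print(identifier)  # Llama3.1-8B-Instruct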
@@ -13,7 +13,7 @@ from llama_models.sku_list import all_registered_models, resolve_model
 from openai import OpenAI

 from llama_stack.apis.inference import *  # noqa: F403
-from llama_stack.providers.datatypes import ModelsProtocolPrivate
+from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate

 from llama_stack.providers.utils.inference.openai_compat import (
     get_sampling_options,
@@ -44,13 +44,13 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
     async def initialize(self) -> None:
         self.client = OpenAI(base_url=self.config.url, api_key=self.config.api_token)

-    async def register_model(self, model: ModelDef) -> None:
+    async def register_model(self, model: Model) -> None:
         raise ValueError("Model registration is not supported for vLLM models")

     async def shutdown(self) -> None:
         pass

-    async def list_models(self) -> List[ModelDef]:
+    async def list_models(self) -> List[Model]:
         models = []
         for model in self.client.models.list():
             repo = model.id
@@ -60,7 +60,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):

             identifier = self.huggingface_repo_to_llama_model_id[repo]
             models.append(
-                ModelDef(
+                Model(
                     identifier=identifier,
                     llama_model=identifier,
                 )
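A short usage sketch of the initialization and listing calls visible in the vLLM hunks: point the OpenAI client at a vLLM server's OpenAI-compatible endpoint and iterate its served models. The URL and token below are placeholders, and running it requires an actual vLLM server.

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="fake-token")  # placeholder endpoint
for model in client.models.list():
    print(model.id)  # typically the HF repo the server was launched with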