Merge branch 'main' into inference_refactor

Botao Chen 2024-12-16 16:47:57 -08:00
commit 6a51e2268d
117 changed files with 12698 additions and 2589 deletions

View file

@@ -257,6 +257,8 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
         endpoints = get_all_api_endpoints()
         endpoint_impls = {}
         for api, api_endpoints in endpoints.items():
+            if api not in self.impls:
+                continue
             for endpoint in api_endpoints:
                 impl = self.impls[api]
                 func = getattr(impl, endpoint.name)
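
The new guard matters because self.impls only contains the APIs that actually have a configured provider; without it, self.impls[api] raises a KeyError for any declared API (such as the new post_training) that is not part of the run config. A minimal sketch of the same pattern, using hypothetical helper names rather than the library's actual classes:

```python
# Minimal sketch of the route-building pattern above; Endpoint and
# build_endpoint_impls are hypothetical, not the actual llama-stack classes.
from dataclasses import dataclass
from typing import Any, Callable, Dict, List


@dataclass
class Endpoint:
    route: str
    name: str  # attribute name on the implementation object


def build_endpoint_impls(
    endpoints: Dict[str, List[Endpoint]],
    impls: Dict[str, Any],
) -> Dict[str, Callable]:
    endpoint_impls: Dict[str, Callable] = {}
    for api, api_endpoints in endpoints.items():
        if api not in impls:  # skip APIs with no configured provider
            continue
        for endpoint in api_endpoints:
            impl = impls[api]
            endpoint_impls[endpoint.route] = getattr(impl, endpoint.name)
    return endpoint_impls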

View file

@@ -24,6 +24,7 @@ from llama_stack.apis.inspect import Inspect
 from llama_stack.apis.memory import Memory
 from llama_stack.apis.memory_banks import MemoryBanks
 from llama_stack.apis.models import Models
+from llama_stack.apis.post_training import PostTraining
 from llama_stack.apis.safety import Safety
 from llama_stack.apis.scoring import Scoring
 from llama_stack.apis.scoring_functions import ScoringFunctions
@@ -58,6 +59,7 @@ def api_protocol_map() -> Dict[Api, Any]:
         Api.scoring_functions: ScoringFunctions,
         Api.eval: Eval,
         Api.eval_tasks: EvalTasks,
+        Api.post_training: PostTraining,
     }
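
With Api.post_training mapped to the PostTraining protocol, the resolver can verify a post-training provider implements that interface the same way it does for every other API. A hedged sketch of how a protocol map like this is typically consumed; check_protocol_compliance and the protocol body here are illustrative, not the actual resolver internals:

```python
# Illustrative protocol check; names and signatures are simplified and are not
# the actual llama-stack resolver helpers.
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class PostTrainingProtocol(Protocol):
    async def supervised_fine_tune(self, job_uuid: str) -> None: ...


def check_protocol_compliance(impl: Any, protocol: type) -> None:
    # isinstance() with a runtime_checkable Protocol only checks that the
    # required methods exist on the implementation, not their signatures.
    if not isinstance(impl, protocol):
        raise ValueError(
            f"{type(impl).__name__} does not implement {protocol.__name__}"
        )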

View file

@@ -111,7 +111,7 @@ class InferenceRouter(Inference):
         model = await self.routing_table.get_model(model_id)
         if model is None:
             raise ValueError(f"Model '{model_id}' not found")
-        if model.model_type == ModelType.embedding_model:
+        if model.model_type == ModelType.embedding:
             raise ValueError(
                 f"Model '{model_id}' is an embedding model and does not support chat completions"
             )
@@ -144,7 +144,7 @@ class InferenceRouter(Inference):
         model = await self.routing_table.get_model(model_id)
         if model is None:
             raise ValueError(f"Model '{model_id}' not found")
-        if model.model_type == ModelType.embedding_model:
+        if model.model_type == ModelType.embedding:
             raise ValueError(
                 f"Model '{model_id}' is an embedding model and does not support chat completions"
             )
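
Both router methods now compare against ModelType.embedding (the member was renamed from embedding_model as part of this refactor), so pointing chat completion at an embedding model fails fast with a clear message instead of a provider-level error. A hedged usage sketch; the base URL and model id are placeholders and assume an embedding model is registered under that id:

```python
# Hedged sketch of the caller-facing behavior; base_url and model id are
# placeholders, and the exact exception type depends on the client transport.
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:5000")

try:
    client.inference.chat_completion(
        model_id="all-MiniLM-L6-v2",  # registered as an embedding model
        messages=[{"role": "user", "content": "What is the capital of France?"}],
    )
except Exception as err:
    # Expected: a message like "... is an embedding model and does not
    # support chat completions"
    print(err)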

View file

@@ -233,10 +233,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
             metadata = {}
         if model_type is None:
             model_type = ModelType.llm
-        if (
-            "embedding_dimension" not in metadata
-            and model_type == ModelType.embedding_model
-        ):
+        if "embedding_dimension" not in metadata and model_type == ModelType.embedding:
             raise ValueError(
                 "Embedding model must have an embedding dimension in its metadata"
             )
@@ -323,8 +320,15 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
             )
         model = await self.get_object_by_identifier("model", params.embedding_model)
         if model is None:
-            raise ValueError(f"Model {params.embedding_model} not found")
-        if model.model_type != ModelType.embedding_model:
+            if params.embedding_model == "all-MiniLM-L6-v2":
+                raise ValueError(
+                    "Embeddings are now served via Inference providers. "
+                    "Please upgrade your run.yaml to include inline::sentence-transformer as an additional inference provider. "
+                    "See https://github.com/meta-llama/llama-stack/blob/main/llama_stack/templates/together/run.yaml for an example."
+                )
+            else:
+                raise ValueError(f"Model {params.embedding_model} not found")
+        if model.model_type != ModelType.embedding:
             raise ValueError(
                 f"Model {params.embedding_model} is not an embedding model"
             )
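
Taken together, the two routing-table changes require an embedding model to be registered first, with an embedding_dimension in its metadata (typically served by an inline sentence-transformers inference provider), before a memory bank can reference it; the special-cased all-MiniLM-L6-v2 message guides users who relied on the old built-in embeddings. A hedged sketch of that ordering; the provider id, the 384 dimension, and the memory-bank parameters are assumptions about the client API, not verified signatures:

```python
# Hedged sketch of the registration order implied by the checks above.
# provider_id, the 384 dimension, and the memory_banks.register parameters are
# assumptions and may differ across llama-stack client versions.
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:5000")

# 1. Register the embedding model against an inference provider.
client.models.register(
    model_id="all-MiniLM-L6-v2",
    provider_id="sentence-transformers",
    model_type="embedding",
    metadata={"embedding_dimension": 384},  # required for embedding models
)

# 2. Only then register a memory bank that references it.
client.memory_banks.register(
    memory_bank_id="my-docs",
    params={
        "memory_bank_type": "vector",
        "embedding_model": "all-MiniLM-L6-v2",
        "chunk_size_in_tokens": 512,
    },
)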

View file

@@ -29,7 +29,8 @@ def main(config_path: str):
         print("No models found, skipping chat completion test")
         return
-    model_id = models[0].identifier
+    model_id = next(m.identifier for m in models if "8b" in m.identifier.lower())
+    print(f"Using model: {model_id}")
     response = client.inference.chat_completion(
         messages=[UserMessage(content="What is the capital of France?", role="user")],
         model_id=model_id,
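
One caveat with the new selection: next() over a bare generator raises StopIteration when no registered model id contains "8b". A hedged variant with a fallback keeps the script usable on stacks without such a model:

```python
# Fall back to the first registered model instead of raising StopIteration
# when no "8b" model is available.
model_id = next(
    (m.identifier for m in models if "8b" in m.identifier.lower()),
    models[0].identifier,
)
print(f"Using model: {model_id}")
```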