mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 04:04:14 +00:00
feat: add static embedding metadata to dynamic model listings for providers using OpenAIMixin (#3547)
# What does this PR do? - remove auto-download of ollama embedding models - add embedding model metadata to dynamic listing w/ unit test - add support and tests for allowed_models - removed inference provider models.py files where dynamic listing is enabled - store embedding metadata in embedding_model_metadata field on inference providers - make model_entries optional on ModelRegistryHelper and LiteLLMOpenAIMixin - make OpenAIMixin a ModelRegistryHelper - skip base64 embedding test for remote::ollama, always returns floats - only use OpenAI client for ollama model listing - remove unused build_model_entry function - remove unused get_huggingface_repo function ## Test Plan ci w/ new tests
This commit is contained in:
parent
a50b63906c
commit
b67aef2fc4
43 changed files with 368 additions and 1015 deletions
|
@ -23,6 +23,8 @@ from llama_stack.apis.inference import (
|
|||
Inference,
|
||||
LogProbConfig,
|
||||
Message,
|
||||
Model,
|
||||
ModelType,
|
||||
OpenAICompletion,
|
||||
ResponseFormat,
|
||||
SamplingParams,
|
||||
|
@ -32,11 +34,7 @@ from llama_stack.apis.inference import (
|
|||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.apis.models import Model, ModelType
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
ProviderModelEntry,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
|
||||
from .config import DatabricksImplConfig
|
||||
|
@ -44,29 +42,16 @@ from .config import DatabricksImplConfig
|
|||
logger = get_logger(name=__name__, category="inference::databricks")
|
||||
|
||||
|
||||
# source: https://docs.databricks.com/aws/en/machine-learning/foundation-model-apis/supported-models
|
||||
EMBEDDING_MODEL_ENTRIES = {
|
||||
"databricks-gte-large-en": ProviderModelEntry(
|
||||
provider_model_id="databricks-gte-large-en",
|
||||
metadata={
|
||||
"embedding_dimension": 1024,
|
||||
"context_length": 8192,
|
||||
},
|
||||
),
|
||||
"databricks-bge-large-en": ProviderModelEntry(
|
||||
provider_model_id="databricks-bge-large-en",
|
||||
metadata={
|
||||
"embedding_dimension": 1024,
|
||||
"context_length": 512,
|
||||
},
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
class DatabricksInferenceAdapter(
|
||||
OpenAIMixin,
|
||||
Inference,
|
||||
):
|
||||
# source: https://docs.databricks.com/aws/en/machine-learning/foundation-model-apis/supported-models
|
||||
embedding_model_metadata = {
|
||||
"databricks-gte-large-en": {"embedding_dimension": 1024, "context_length": 8192},
|
||||
"databricks-bge-large-en": {"embedding_dimension": 1024, "context_length": 512},
|
||||
}
|
||||
|
||||
def __init__(self, config: DatabricksImplConfig) -> None:
|
||||
self.config = config
|
||||
|
||||
|
@ -156,11 +141,11 @@ class DatabricksInferenceAdapter(
|
|||
if endpoint.task == "llm/v1/chat":
|
||||
model.model_type = ModelType.llm # this is redundant, but informative
|
||||
elif endpoint.task == "llm/v1/embeddings":
|
||||
if endpoint.name not in EMBEDDING_MODEL_ENTRIES:
|
||||
if endpoint.name not in self.embedding_model_metadata:
|
||||
logger.warning(f"No metadata information available for embedding model {endpoint.name}, skipping.")
|
||||
continue
|
||||
model.model_type = ModelType.embedding
|
||||
model.metadata = EMBEDDING_MODEL_ENTRIES[endpoint.name].metadata
|
||||
model.metadata = self.embedding_model_metadata[endpoint.name]
|
||||
else:
|
||||
logger.warning(f"Unknown model type, skipping: {endpoint}")
|
||||
continue
|
||||
|
@ -169,13 +154,5 @@ class DatabricksInferenceAdapter(
|
|||
|
||||
return list(self._model_cache.values())
|
||||
|
||||
async def register_model(self, model: Model) -> Model:
|
||||
if not await self.check_model_availability(model.provider_resource_id):
|
||||
raise ValueError(f"Model {model.provider_resource_id} is not available in Databricks workspace.")
|
||||
return model
|
||||
|
||||
async def unregister_model(self, model_id: str) -> None:
|
||||
pass
|
||||
|
||||
async def should_refresh_models(self) -> bool:
|
||||
return False
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue