chore: make OpenAIMixin maintainable, turn OpenAIMixin into a pydantic.BaseModel

- implement get_api_key instead of relying on LiteLLMOpenAIMixin.get_api_key
 - remove use of LiteLLMOpenAIMixin
 - add default initialize/shutdown methods to OpenAIMixin
 - remove __init__s to allow proper pydantic construction
 - remove dead code from vllm adapter and associated / duplicate unit tests
 - update vllm adapter to use openaimixin for model registration
 - remove ModelRegistryHelper from fireworks & together adapters
 - remove Inference from nvidia adapter
 - complete type hints on embedding_model_metadata
 - allow extra fields on OpenAIMixin, for model_store, __provider_id__, etc
 - new recordings for ollama
 - enhance the list models error handling w/ new tests
 - update cerebras (remove cerebras-cloud-sdk) and anthropic (custom model listing) inference adapters
 - parametrized test_inference_client_caching
 - remove cerebras, databricks, fireworks, together from blanket mypy exclude
This commit is contained in:
Matthew Farrellee 2025-10-02 20:47:54 -04:00
parent 351c4b98e4
commit fd06717d87
64 changed files with 12901 additions and 1734 deletions

View file

@ -9,11 +9,8 @@ from typing import Any
from databricks.sdk import WorkspaceClient
from llama_stack.apis.inference import (
Inference,
Model,
OpenAICompletion,
)
from llama_stack.apis.models import ModelType
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
@ -22,30 +19,31 @@ from .config import DatabricksImplConfig
logger = get_logger(name=__name__, category="inference::databricks")
class DatabricksInferenceAdapter(
OpenAIMixin,
Inference,
):
class DatabricksInferenceAdapter(OpenAIMixin):
config: DatabricksImplConfig
# source: https://docs.databricks.com/aws/en/machine-learning/foundation-model-apis/supported-models
embedding_model_metadata = {
embedding_model_metadata: dict[str, dict[str, int]] = {
"databricks-gte-large-en": {"embedding_dimension": 1024, "context_length": 8192},
"databricks-bge-large-en": {"embedding_dimension": 1024, "context_length": 512},
}
def __init__(self, config: DatabricksImplConfig) -> None:
self.config = config
def get_api_key(self) -> str:
return self.config.api_token.get_secret_value()
def get_base_url(self) -> str:
return f"{self.config.url}/serving-endpoints"
async def initialize(self) -> None:
return
async def get_models(self) -> list[str] | None:
return [
endpoint.name
for endpoint in WorkspaceClient(
host=self.config.url, token=self.get_api_key()
).serving_endpoints.list() # TODO: this is not async
]
async def shutdown(self) -> None:
pass
async def should_refresh_models(self) -> bool:
return False
async def openai_completion(
self,
@ -71,32 +69,3 @@ class DatabricksInferenceAdapter(
suffix: str | None = None,
) -> OpenAICompletion:
raise NotImplementedError()
async def list_models(self) -> list[Model] | None:
self._model_cache = {} # from OpenAIMixin
ws_client = WorkspaceClient(host=self.config.url, token=self.get_api_key()) # TODO: this is not async
endpoints = ws_client.serving_endpoints.list()
for endpoint in endpoints:
model = Model(
provider_id=self.__provider_id__,
provider_resource_id=endpoint.name,
identifier=endpoint.name,
)
if endpoint.task == "llm/v1/chat":
model.model_type = ModelType.llm # this is redundant, but informative
elif endpoint.task == "llm/v1/embeddings":
if endpoint.name not in self.embedding_model_metadata:
logger.warning(f"No metadata information available for embedding model {endpoint.name}, skipping.")
continue
model.model_type = ModelType.embedding
model.metadata = self.embedding_model_metadata[endpoint.name]
else:
logger.warning(f"Unknown model type, skipping: {endpoint}")
continue
self._model_cache[endpoint.name] = model
return list(self._model_cache.values())
async def should_refresh_models(self) -> bool:
return False