chore: turn OpenAIMixin into a pydantic.BaseModel (#3671)

# What does this PR do?

- implement get_api_key instead of relying on
LiteLLMOpenAIMixin.get_api_key
 - remove use of LiteLLMOpenAIMixin
 - add default initialize/shutdown methods to OpenAIMixin
 - remove __init__s to allow proper pydantic construction
- remove dead code from vllm adapter and associated / duplicate unit
tests
 - update vllm adapter to use openaimixin for model registration
 - remove ModelRegistryHelper from fireworks & together adapters
 - remove Inference from nvidia adapter
 - complete type hints on embedding_model_metadata
- allow extra fields on OpenAIMixin, for model_store, __provider_id__,
etc
 - new recordings for ollama
 - enhance the list models error handling
- update cerebras (remove cerebras-cloud-sdk) and anthropic (custom
model listing) inference adapters
 - parametrized test_inference_client_caching
- remove cerebras, databricks, fireworks, together from blanket mypy
exclude
 - removed unnecessary litellm deps

## Test Plan

ci
This commit is contained in:
Matthew Farrellee 2025-10-06 11:33:19 -04:00 committed by GitHub
parent 724dac498c
commit d23ed26238
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
131 changed files with 83634 additions and 1760 deletions

View file

@ -11,6 +11,6 @@ async def get_adapter_impl(config: DatabricksImplConfig, _deps):
from .databricks import DatabricksInferenceAdapter
assert isinstance(config, DatabricksImplConfig), f"Unexpected config type: {type(config)}"
impl = DatabricksInferenceAdapter(config)
impl = DatabricksInferenceAdapter(config=config)
await impl.initialize()
return impl

View file

@ -14,12 +14,12 @@ from llama_stack.schema_utils import json_schema_type
@json_schema_type
class DatabricksImplConfig(RemoteInferenceProviderConfig):
url: str = Field(
url: str | None = Field(
default=None,
description="The URL for the Databricks model serving endpoint",
)
api_token: SecretStr = Field(
default=SecretStr(None),
default=SecretStr(None), # type: ignore[arg-type]
description="The Databricks API token",
)

View file

@ -9,10 +9,7 @@ from typing import Any
from databricks.sdk import WorkspaceClient
from llama_stack.apis.inference import (
Inference,
OpenAICompletion,
)
from llama_stack.apis.inference import OpenAICompletion
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
@ -21,30 +18,31 @@ from .config import DatabricksImplConfig
logger = get_logger(name=__name__, category="inference::databricks")
class DatabricksInferenceAdapter(
OpenAIMixin,
Inference,
):
class DatabricksInferenceAdapter(OpenAIMixin):
config: DatabricksImplConfig
# source: https://docs.databricks.com/aws/en/machine-learning/foundation-model-apis/supported-models
embedding_model_metadata = {
embedding_model_metadata: dict[str, dict[str, int]] = {
"databricks-gte-large-en": {"embedding_dimension": 1024, "context_length": 8192},
"databricks-bge-large-en": {"embedding_dimension": 1024, "context_length": 512},
}
def __init__(self, config: DatabricksImplConfig) -> None:
self.config = config
def get_api_key(self) -> str:
return self.config.api_token.get_secret_value()
def get_base_url(self) -> str:
return f"{self.config.url}/serving-endpoints"
async def initialize(self) -> None:
return
async def list_provider_model_ids(self) -> Iterable[str]:
return [
endpoint.name
for endpoint in WorkspaceClient(
host=self.config.url, token=self.get_api_key()
).serving_endpoints.list() # TODO: this is not async
]
async def shutdown(self) -> None:
pass
async def should_refresh_models(self) -> bool:
return False
async def openai_completion(
self,
@ -70,14 +68,3 @@ class DatabricksInferenceAdapter(
suffix: str | None = None,
) -> OpenAICompletion:
raise NotImplementedError()
async def list_provider_model_ids(self) -> Iterable[str]:
return [
endpoint.name
for endpoint in WorkspaceClient(
host=self.config.url, token=self.get_api_key()
).serving_endpoints.list() # TODO: this is not async
]
async def should_refresh_models(self) -> bool:
return False