update anthropic, databricks, tgi, together after get_models rename

This commit is contained in:
Matthew Farrellee 2025-10-06 09:08:09 -04:00
parent 733e72caf9
commit 8658949bbb
4 changed files with 4 additions and 12 deletions

View file

@ -35,5 +35,5 @@ class AnthropicInferenceAdapter(OpenAIMixin):
def get_base_url(self): def get_base_url(self):
return "https://api.anthropic.com/v1" return "https://api.anthropic.com/v1"
async def get_models(self) -> Iterable[str] | None: async def list_provider_model_ids(self) -> Iterable[str]:
return [m.id async for m in AsyncAnthropic(api_key=self.get_api_key()).models.list()] return [m.id async for m in AsyncAnthropic(api_key=self.get_api_key()).models.list()]

View file

@ -33,7 +33,7 @@ class DatabricksInferenceAdapter(OpenAIMixin):
def get_base_url(self) -> str: def get_base_url(self) -> str:
return f"{self.config.url}/serving-endpoints" return f"{self.config.url}/serving-endpoints"
async def get_models(self) -> list[str] | None: async def list_provider_model_ids(self) -> Iterable[str]:
return [ return [
endpoint.name endpoint.name
for endpoint in WorkspaceClient( for endpoint in WorkspaceClient(
@ -68,11 +68,3 @@ class DatabricksInferenceAdapter(OpenAIMixin):
suffix: str | None = None, suffix: str | None = None,
) -> OpenAICompletion: ) -> OpenAICompletion:
raise NotImplementedError() raise NotImplementedError()
async def list_provider_model_ids(self) -> Iterable[str]:
return [
endpoint.name
for endpoint in WorkspaceClient(
host=self.config.url, token=self.get_api_key()
).serving_endpoints.list() # TODO: this is not async
]

View file

@ -35,7 +35,7 @@ class _HfAdapter(OpenAIMixin):
def get_base_url(self): def get_base_url(self):
return self.url return self.url
async def get_models(self) -> Iterable[str] | None: async def list_provider_model_ids(self) -> Iterable[str]:
return [self.model_id] return [self.model_id]
async def openai_embeddings( async def openai_embeddings(

View file

@ -59,7 +59,7 @@ class TogetherInferenceAdapter(OpenAIMixin, NeedsRequestProviderData):
together_api_key = provider_data.together_api_key together_api_key = provider_data.together_api_key
return AsyncTogether(api_key=together_api_key) return AsyncTogether(api_key=together_api_key)
async def get_models(self) -> Iterable[str] | None: async def list_provider_model_ids(self) -> Iterable[str]:
# Together's /v1/models is not compatible with OpenAI's /v1/models. Together support ticket #13355 -> will not fix, use Together's own client # Together's /v1/models is not compatible with OpenAI's /v1/models. Together support ticket #13355 -> will not fix, use Together's own client
return [m.id for m in await self._get_client().models.list()] return [m.id for m in await self._get_client().models.list()]