feat(databricks.py): add embedding model support

This commit is contained in:
Krrish Dholakia 2024-05-23 18:22:03 -07:00
parent d2229dcd21
commit 43353c28b3
7 changed files with 310 additions and 18 deletions

View file

@ -2727,7 +2727,7 @@ def batch_completion_models_all_responses(*args, **kwargs):
### EMBEDDING ENDPOINTS ####################
@client
async def aembedding(*args, **kwargs):
async def aembedding(*args, **kwargs) -> EmbeddingResponse:
"""
Asynchronously calls the `embedding` function with the given arguments and keyword arguments.
@ -2772,12 +2772,13 @@ async def aembedding(*args, **kwargs):
or custom_llm_provider == "fireworks_ai"
or custom_llm_provider == "ollama"
or custom_llm_provider == "vertex_ai"
or custom_llm_provider == "databricks"
): # currently implemented aiohttp calls for just azure and openai, soon all.
# Await normally
init_response = await loop.run_in_executor(None, func_with_context)
if isinstance(init_response, dict) or isinstance(
init_response, ModelResponse
): ## CACHING SCENARIO
if isinstance(init_response, dict):
response = EmbeddingResponse(**init_response)
elif isinstance(init_response, EmbeddingResponse): ## CACHING SCENARIO
response = init_response
elif asyncio.iscoroutine(init_response):
response = await init_response
@ -2817,7 +2818,7 @@ def embedding(
litellm_logging_obj=None,
logger_fn=None,
**kwargs,
):
) -> EmbeddingResponse:
"""
Embedding function that calls an API to generate embeddings for the given input.
@ -2965,7 +2966,7 @@ def embedding(
)
try:
response = None
logging = litellm_logging_obj
logging: Logging = litellm_logging_obj # type: ignore
logging.update_environment_variables(
model=model,
user=user,
@ -3055,6 +3056,32 @@ def embedding(
client=client,
aembedding=aembedding,
)
elif custom_llm_provider == "databricks":
api_base = (
api_base or litellm.api_base or get_secret("DATABRICKS_API_BASE")
) # type: ignore
# set API KEY
api_key = (
api_key
or litellm.api_key
or litellm.databricks_key
or get_secret("DATABRICKS_API_KEY")
) # type: ignore
## EMBEDDING CALL
response = databricks_chat_completions.embedding(
model=model,
input=input,
api_base=api_base,
api_key=api_key,
logging_obj=logging,
timeout=timeout,
model_response=EmbeddingResponse(),
optional_params=optional_params,
client=client,
aembedding=aembedding,
)
elif custom_llm_provider == "cohere":
cohere_key = (
api_key