(feat) add xinference as an embedding provider

This commit is contained in:
ishaan-jaff 2024-01-02 15:32:26 +05:30
parent 0d0ee9e108
commit 790dcff5e0
3 changed files with 41 additions and 2 deletions

View file

@@ -260,7 +260,13 @@ openai_compatible_endpoints: List = [
]
# this is maintained for Exception Mapping
openai_compatible_providers: List = ["anyscale", "mistral", "deepinfra", "perplexity"]
openai_compatible_providers: List = [
"anyscale",
"mistral",
"deepinfra",
"perplexity",
"xinference",
]
# well supported replicate llms
@@ -401,6 +407,7 @@ provider_list: List = [
"maritalk",
"voyage",
"cloudflare",
"xinference",
"custom", # custom apis
]

View file

@@ -2072,6 +2072,9 @@ async def aembedding(*args, **kwargs):
if (
custom_llm_provider == "openai"
or custom_llm_provider == "azure"
or custom_llm_provider == "xinference"
or custom_llm_provider == "voyage"
or custom_llm_provider == "mistral"
or custom_llm_provider == "custom_openai"
or custom_llm_provider == "anyscale"
or custom_llm_provider == "openrouter"
@@ -2416,6 +2419,31 @@ def embedding(
client=client,
aembedding=aembedding,
)
elif custom_llm_provider == "xinference":
api_key = (
api_key
or litellm.api_key
or get_secret("XINFERENCE_API_KEY")
or "stub-xinference-key"
) # xinference does not need an api key, pass a stub key if user did not set one
api_base = (
api_base
or litellm.api_base
or get_secret("XINFERENCE_API_BASE")
or "http://127.0.0.1:9997/v1"
)
response = openai_chat_completions.embedding(
model=model,
input=input,
api_base=api_base,
api_key=api_key,
logging_obj=logging,
timeout=timeout,
model_response=EmbeddingResponse(),
optional_params=optional_params,
client=client,
aembedding=aembedding,
)
else:
args = locals()
raise ValueError(f"No valid embedding model args passed in - {args}")

View file

@@ -3008,7 +3008,11 @@ def get_optional_params_embeddings(
if (k in default_params and v != default_params[k])
}
## raise exception if non-default value passed for non-openai/azure embedding calls
if custom_llm_provider != "openai" and custom_llm_provider != "azure":
if (
custom_llm_provider != "openai"
and custom_llm_provider != "azure"
and custom_llm_provider not in litellm.openai_compatible_providers
):
if len(non_default_params.keys()) > 0:
if litellm.drop_params is True: # drop the unsupported non-default values
keys = list(non_default_params.keys())