From 790dcff5e040bd10e57233e23311c2f8179ca9ce Mon Sep 17 00:00:00 2001
From: ishaan-jaff
Date: Tue, 2 Jan 2024 15:32:26 +0530
Subject: [PATCH] (feat) add xinference as an embedding provider

---
 litellm/__init__.py |  9 ++++++++-
 litellm/main.py     | 28 ++++++++++++++++++++++++++++
 litellm/utils.py    |  6 +++++-
 3 files changed, 41 insertions(+), 2 deletions(-)

diff --git a/litellm/__init__.py b/litellm/__init__.py
index 510832a57..8668fe850 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -260,7 +260,13 @@ openai_compatible_endpoints: List = [
 ]
 
 # this is maintained for Exception Mapping
-openai_compatible_providers: List = ["anyscale", "mistral", "deepinfra", "perplexity"]
+openai_compatible_providers: List = [
+    "anyscale",
+    "mistral",
+    "deepinfra",
+    "perplexity",
+    "xinference",
+]
 
 
 # well supported replicate llms
@@ -401,6 +407,7 @@ provider_list: List = [
     "maritalk",
     "voyage",
     "cloudflare",
+    "xinference",
     "custom",  # custom apis
 ]
 
diff --git a/litellm/main.py b/litellm/main.py
index befb2733e..ea6b57154 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -2072,6 +2072,9 @@ async def aembedding(*args, **kwargs):
         if (
             custom_llm_provider == "openai"
             or custom_llm_provider == "azure"
+            or custom_llm_provider == "xinference"
+            or custom_llm_provider == "voyage"
+            or custom_llm_provider == "mistral"
             or custom_llm_provider == "custom_openai"
             or custom_llm_provider == "anyscale"
             or custom_llm_provider == "openrouter"
@@ -2416,6 +2419,31 @@ def embedding(
             client=client,
             aembedding=aembedding,
         )
+    elif custom_llm_provider == "xinference":
+        api_key = (
+            api_key
+            or litellm.api_key
+            or get_secret("XINFERENCE_API_KEY")
+            or "stub-xinference-key"
+        )  # xinference does not need an api key, pass a stub key if user did not set one
+        api_base = (
+            api_base
+            or litellm.api_base
+            or get_secret("XINFERENCE_API_BASE")
+            or "http://127.0.0.1:9997/v1"
+        )
+        response = openai_chat_completions.embedding(
+            model=model,
+            input=input,
+            api_base=api_base,
+            api_key=api_key,
+            logging_obj=logging,
+            timeout=timeout,
+            model_response=EmbeddingResponse(),
+            optional_params=optional_params,
+            client=client,
+            aembedding=aembedding,
+        )
     else:
         args = locals()
         raise ValueError(f"No valid embedding model args passed in - {args}")
diff --git a/litellm/utils.py b/litellm/utils.py
index d8a75934a..219a32949 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -3008,7 +3008,11 @@ def get_optional_params_embeddings(
         if (k in default_params and v != default_params[k])
     }
     ## raise exception if non-default value passed for non-openai/azure embedding calls
-    if custom_llm_provider != "openai" and custom_llm_provider != "azure":
+    if (
+        custom_llm_provider != "openai"
+        and custom_llm_provider != "azure"
+        and custom_llm_provider not in litellm.openai_compatible_providers
+    ):
         if len(non_default_params.keys()) > 0:
             if litellm.drop_params is True:  # drop the unsupported non-default values
                 keys = list(non_default_params.keys())
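
Usage sketch (editor's note, not part of the patch): a minimal call exercising
the new `elif custom_llm_provider == "xinference"` branch in `embedding()`,
assuming a local Xinference server reachable at the default base URL the patch
falls back to, with an embedding model already launched under the illustrative
name "bge-base-en"; that model name is an assumption, not something the patch
pins down.

    import litellm

    # "xinference/<model>" selects the new provider branch added above.
    # No credentials are needed for a local server: per the patch, the key
    # falls back to a stub value and the base URL to http://127.0.0.1:9997/v1
    # when neither is set via arguments, litellm config, or env vars.
    response = litellm.embedding(
        model="xinference/bge-base-en",  # hypothetical model name
        input=["good morning from litellm"],
    )

    # EmbeddingResponse mirrors the OpenAI embedding response shape that the
    # reused openai_chat_completions.embedding() handler produces: one entry
    # in response.data per input string.
    print(len(response.data[0]["embedding"]))

Because Xinference exposes an OpenAI-compatible `/v1` API, the patch reuses the
existing OpenAI embedding handler rather than adding a new provider module;
adding "xinference" to `openai_compatible_providers` also keeps
`get_optional_params_embeddings()` from rejecting OpenAI-style optional params
for it.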