(feat) add xinference as an embedding provider

This commit is contained in:
ishaan-jaff 2024-01-02 15:32:26 +05:30
parent 0d0ee9e108
commit 790dcff5e0
3 changed files with 41 additions and 2 deletions

View file

@ -2072,6 +2072,9 @@ async def aembedding(*args, **kwargs):
if (
custom_llm_provider == "openai"
or custom_llm_provider == "azure"
or custom_llm_provider == "xinference"
or custom_llm_provider == "voyage"
or custom_llm_provider == "mistral"
or custom_llm_provider == "custom_openai"
or custom_llm_provider == "anyscale"
or custom_llm_provider == "openrouter"
@ -2416,6 +2419,31 @@ def embedding(
client=client,
aembedding=aembedding,
)
elif custom_llm_provider == "xinference":
api_key = (
api_key
or litellm.api_key
or get_secret("XINFERENCE_API_KEY")
or "stub-xinference-key"
) # xinference does not need an api key, pass a stub key if user did not set one
api_base = (
api_base
or litellm.api_base
or get_secret("XINFERENCE_API_BASE")
or "http://127.0.0.1:9997/v1"
)
response = openai_chat_completions.embedding(
model=model,
input=input,
api_base=api_base,
api_key=api_key,
logging_obj=logging,
timeout=timeout,
model_response=EmbeddingResponse(),
optional_params=optional_params,
client=client,
aembedding=aembedding,
)
else:
args = locals()
raise ValueError(f"No valid embedding model args passed in - {args}")