forked from phoenix/litellm-mirror
feat - triton embeddings
This commit is contained in:
parent
0cde9473c9
commit
d3550379b0
2 changed files with 140 additions and 0 deletions
|
@ -47,6 +47,7 @@ from .llms import (
|
|||
ai21,
|
||||
sagemaker,
|
||||
bedrock,
|
||||
triton,
|
||||
huggingface_restapi,
|
||||
replicate,
|
||||
aleph_alpha,
|
||||
|
@ -75,6 +76,7 @@ from .llms.anthropic import AnthropicChatCompletion
|
|||
from .llms.anthropic_text import AnthropicTextCompletion
|
||||
from .llms.huggingface_restapi import Huggingface
|
||||
from .llms.predibase import PredibaseChatCompletion
|
||||
from .llms.triton import TritonChatCompletion
|
||||
from .llms.prompt_templates.factory import (
|
||||
prompt_factory,
|
||||
custom_prompt,
|
||||
|
@ -112,6 +114,7 @@ azure_chat_completions = AzureChatCompletion()
|
|||
azure_text_completions = AzureTextCompletion()
|
||||
huggingface = Huggingface()
|
||||
predibase_chat_completions = PredibaseChatCompletion()
|
||||
triton_chat_completions = TritonChatCompletion()
|
||||
####### COMPLETION ENDPOINTS ################
|
||||
|
||||
|
||||
|
@ -2622,6 +2625,7 @@ async def aembedding(*args, **kwargs):
|
|||
or custom_llm_provider == "voyage"
|
||||
or custom_llm_provider == "mistral"
|
||||
or custom_llm_provider == "custom_openai"
|
||||
or custom_llm_provider == "triton"
|
||||
or custom_llm_provider == "anyscale"
|
||||
or custom_llm_provider == "openrouter"
|
||||
or custom_llm_provider == "deepinfra"
|
||||
|
@ -2955,6 +2959,23 @@ def embedding(
|
|||
optional_params=optional_params,
|
||||
model_response=EmbeddingResponse(),
|
||||
)
|
||||
elif custom_llm_provider == "triton":
|
||||
if api_base is None:
|
||||
raise ValueError(
|
||||
"api_base is required for triton. Please pass `api_base`"
|
||||
)
|
||||
response = triton_chat_completions.embedding(
|
||||
model=model,
|
||||
input=input,
|
||||
api_base=api_base,
|
||||
api_key=api_key,
|
||||
logging_obj=logging,
|
||||
timeout=timeout,
|
||||
model_response=EmbeddingResponse(),
|
||||
optional_params=optional_params,
|
||||
client=client,
|
||||
aembedding=aembedding,
|
||||
)
|
||||
elif custom_llm_provider == "vertex_ai":
|
||||
vertex_ai_project = (
|
||||
optional_params.pop("vertex_project", None)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue