feat - triton embeddings

This commit is contained in:
Ishaan Jaff 2024-05-10 18:57:06 -07:00
parent 0cde9473c9
commit d3550379b0
2 changed files with 140 additions and 0 deletions

View file

@ -47,6 +47,7 @@ from .llms import (
ai21,
sagemaker,
bedrock,
triton,
huggingface_restapi,
replicate,
aleph_alpha,
@ -75,6 +76,7 @@ from .llms.anthropic import AnthropicChatCompletion
from .llms.anthropic_text import AnthropicTextCompletion
from .llms.huggingface_restapi import Huggingface
from .llms.predibase import PredibaseChatCompletion
from .llms.triton import TritonChatCompletion
from .llms.prompt_templates.factory import (
prompt_factory,
custom_prompt,
@ -112,6 +114,7 @@ azure_chat_completions = AzureChatCompletion()
azure_text_completions = AzureTextCompletion()
huggingface = Huggingface()
predibase_chat_completions = PredibaseChatCompletion()
triton_chat_completions = TritonChatCompletion()
####### COMPLETION ENDPOINTS ################
@ -2622,6 +2625,7 @@ async def aembedding(*args, **kwargs):
or custom_llm_provider == "voyage"
or custom_llm_provider == "mistral"
or custom_llm_provider == "custom_openai"
or custom_llm_provider == "triton"
or custom_llm_provider == "anyscale"
or custom_llm_provider == "openrouter"
or custom_llm_provider == "deepinfra"
@ -2955,6 +2959,23 @@ def embedding(
optional_params=optional_params,
model_response=EmbeddingResponse(),
)
elif custom_llm_provider == "triton":
if api_base is None:
raise ValueError(
"api_base is required for triton. Please pass `api_base`"
)
response = triton_chat_completions.embedding(
model=model,
input=input,
api_base=api_base,
api_key=api_key,
logging_obj=logging,
timeout=timeout,
model_response=EmbeddingResponse(),
optional_params=optional_params,
client=client,
aembedding=aembedding,
)
elif custom_llm_provider == "vertex_ai":
vertex_ai_project = (
optional_params.pop("vertex_project", None)