[Bug fix]: Triton /infer handler incompatible with batch responses (#7337)

* migrate triton to base llm http handler

* clean up triton handler.py

* use transform functions for triton

* add TritonConfig

* get openai params for triton

* use triton embedding config

* test_completion_triton_generate_api

* test_completion_triton_infer_api

* fix TritonConfig doc string

* use TritonResponseIterator

* fix triton embeddings

* docs triton chat usage
This commit is contained in:
Ishaan Jaff 2024-12-20 20:59:40 -08:00 committed by GitHub
parent 70a9ea99f2
commit 6107f9f3f3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 814 additions and 450 deletions

View file

@@ -128,7 +128,6 @@ from .llms.replicate.chat.handler import completion as replicate_chat_completion
from .llms.sagemaker.chat.handler import SagemakerChatHandler
from .llms.sagemaker.completion.handler import SagemakerLLM
from .llms.together_ai.completion.handler import TogetherAITextCompletion
from .llms.triton.completion.handler import TritonChatCompletion
from .llms.vertex_ai import vertex_ai_non_gemini
from .llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import VertexLLM
from .llms.vertex_ai.gemini_embeddings.batch_embed_content_handler import (
@@ -194,7 +193,6 @@ azure_audio_transcriptions = AzureAudioTranscription()
huggingface = Huggingface()
predibase_chat_completions = PredibaseChatCompletion()
codestral_text_completions = CodestralTextCompletion()
triton_chat_completions = TritonChatCompletion()
bedrock_chat_completion = BedrockLLM()
bedrock_converse_chat_completion = BedrockConverseLLM()
bedrock_embedding = BedrockEmbedding()
@@ -2711,24 +2709,22 @@ def completion( # type: ignore # noqa: PLR0915
elif custom_llm_provider == "triton":
api_base = litellm.api_base or api_base
model_response = triton_chat_completions.completion(
api_base=api_base,
timeout=timeout, # type: ignore
response = base_llm_http_handler.completion(
model=model,
stream=stream,
messages=messages,
acompletion=acompletion,
api_base=api_base,
model_response=model_response,
optional_params=optional_params,
logging_obj=logging,
stream=stream,
acompletion=acompletion,
client=client,
litellm_params=litellm_params,
custom_llm_provider=custom_llm_provider,
timeout=timeout,
headers=headers,
encoding=encoding,
api_key=api_key,
logging_obj=logging,
)
## RESPONSE OBJECT
response = model_response
return response
elif custom_llm_provider == "cloudflare":
api_key = (
api_key
@@ -3477,9 +3473,10 @@ def embedding( # noqa: PLR0915
raise ValueError(
"api_base is required for triton. Please pass `api_base`"
)
response = triton_chat_completions.embedding( # type: ignore
response = base_llm_http_handler.embedding(
model=model,
input=input,
custom_llm_provider=custom_llm_provider,
api_base=api_base,
api_key=api_key,
logging_obj=logging,