Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 03:04:13 +00:00)
[Bug fix]: Triton /infer handler incompatible with batch responses (#7337)
* migrate triton to base llm http handler
* clean up triton handler.py
* use transform functions for triton
* add TritonConfig
* get openai params for triton
* use triton embedding config
* test_completion_triton_generate_api
* test_completion_triton_infer_api
* fix TritonConfig doc string
* use TritonResponseIterator
* fix triton embeddings
* docs triton chat usage
This commit is contained in:
parent
70a9ea99f2
commit
6107f9f3f3
11 changed files with 814 additions and 450 deletions
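With triton migrated onto the base LLM HTTP handler (and per the "docs triton chat usage" item above), chat calls go through `litellm.completion` like any other provider. A minimal sketch of that usage, assuming a locally running Triton server; the model name and `api_base` below are placeholders to substitute for your deployment:

```python
from litellm import completion

# Placeholders: swap in your deployed model name and your Triton server's
# generate (or infer) endpoint.
response = completion(
    model="triton/llama-3-8b-instruct",         # hypothetical model name
    messages=[{"role": "user", "content": "Hello, Triton!"}],
    api_base="http://localhost:8000/generate",  # assumed local endpoint
    max_tokens=64,
)
print(response.choices[0].message.content)
```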
```diff
@@ -128,7 +128,6 @@ from .llms.replicate.chat.handler import completion as replicate_chat_completion
 from .llms.sagemaker.chat.handler import SagemakerChatHandler
 from .llms.sagemaker.completion.handler import SagemakerLLM
 from .llms.together_ai.completion.handler import TogetherAITextCompletion
-from .llms.triton.completion.handler import TritonChatCompletion
 from .llms.vertex_ai import vertex_ai_non_gemini
 from .llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import VertexLLM
 from .llms.vertex_ai.gemini_embeddings.batch_embed_content_handler import (
```
```diff
@@ -194,7 +193,6 @@ azure_audio_transcriptions = AzureAudioTranscription()
 huggingface = Huggingface()
 predibase_chat_completions = PredibaseChatCompletion()
 codestral_text_completions = CodestralTextCompletion()
-triton_chat_completions = TritonChatCompletion()
 bedrock_chat_completion = BedrockLLM()
 bedrock_converse_chat_completion = BedrockConverseLLM()
 bedrock_embedding = BedrockEmbedding()
```
```diff
@@ -2711,24 +2709,22 @@ def completion(  # type: ignore # noqa: PLR0915
     elif custom_llm_provider == "triton":
         api_base = litellm.api_base or api_base
-        model_response = triton_chat_completions.completion(
-            api_base=api_base,
-            timeout=timeout,  # type: ignore
+        response = base_llm_http_handler.completion(
             model=model,
+            stream=stream,
             messages=messages,
+            acompletion=acompletion,
+            api_base=api_base,
             model_response=model_response,
             optional_params=optional_params,
-            logging_obj=logging,
-            stream=stream,
-            acompletion=acompletion,
+            client=client,
+            litellm_params=litellm_params,
+            custom_llm_provider=custom_llm_provider,
+            timeout=timeout,
+            headers=headers,
+            encoding=encoding,
+            api_key=api_key,
+            logging_obj=logging,
         )
-
-        ## RESPONSE OBJECT
-        response = model_response
         return response

     elif custom_llm_provider == "cloudflare":
         api_key = (
             api_key
```
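For context on the incompatibility named in the title: Triton's KServe-v2 `/infer` endpoint returns each output tensor as a flat `data` array with an accompanying `shape`, so a batched request comes back as one list holding every completion in the batch, not a single string; a handler that reads only the first element drops the rest of the batch. A minimal illustrative sketch of unpacking such a response (this is not litellm's actual transform; the `text_output` tensor name is an assumption about the deployed model):

```python
from typing import List

def extract_batch_texts(infer_response: dict) -> List[str]:
    """Collect every generated string from a (possibly batched) /infer response.

    Assumes the model exposes a BYTES tensor named "text_output", as many
    text-generation Triton deployments do; adjust the name for your model.
    """
    texts: List[str] = []
    for output in infer_response.get("outputs", []):
        if output.get("name") == "text_output":
            # "data" is flat across the batch dimension: one entry per
            # batched request, so take all of them, not just data[0].
            texts.extend(str(item) for item in output.get("data", []))
    return texts

# Illustrative batched response (hand-written, not captured from a server):
sample = {
    "model_name": "llama-3-8b-instruct",
    "outputs": [
        {
            "name": "text_output",
            "datatype": "BYTES",
            "shape": [2],
            "data": ["first completion", "second completion"],
        }
    ],
}
assert extract_batch_texts(sample) == ["first completion", "second completion"]
```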
```diff
@@ -3477,9 +3473,10 @@ def embedding(  # noqa: PLR0915
             raise ValueError(
                 "api_base is required for triton. Please pass `api_base`"
             )
-        response = triton_chat_completions.embedding(  # type: ignore
+        response = base_llm_http_handler.embedding(
             model=model,
             input=input,
+            custom_llm_provider=custom_llm_provider,
             api_base=api_base,
             api_key=api_key,
             logging_obj=logging,
```
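The embedding path gets the same migration, so from the caller's side triton embeddings remain a plain `litellm.embedding` call. A hedged sketch, with placeholder model name and endpoint:

```python
from litellm import embedding

# Placeholders: substitute your deployed embedding model and its endpoint.
response = embedding(
    model="triton/my-embedding-model",  # hypothetical model name
    input=["good morning from litellm"],
    api_base="http://localhost:8000/v2/models/my-embedding-model/infer",  # assumed
)
print(len(response.data[0]["embedding"]))
```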