forked from phoenix/litellm-mirror
(fix - watsonx) Fixed issues with watsonx embedding/async endpoints
This commit is contained in:
parent
c7338f9798
commit
06e6f52358
2 changed files with 186 additions and 119 deletions
|
@ -108,6 +108,7 @@ from .llms.databricks import DatabricksChatCompletion
|
|||
from .llms.huggingface_restapi import Huggingface
|
||||
from .llms.openai import OpenAIChatCompletion, OpenAITextCompletion
|
||||
from .llms.predibase import PredibaseChatCompletion
|
||||
from .llms.watsonx import IBMWatsonXAI
|
||||
from .llms.prompt_templates.factory import (
|
||||
custom_prompt,
|
||||
function_call_prompt,
|
||||
|
@ -152,6 +153,7 @@ triton_chat_completions = TritonChatCompletion()
|
|||
bedrock_chat_completion = BedrockLLM()
|
||||
bedrock_converse_chat_completion = BedrockConverseLLM()
|
||||
vertex_chat_completion = VertexLLM()
|
||||
watsonxai = IBMWatsonXAI()
|
||||
####### COMPLETION ENDPOINTS ################
|
||||
|
||||
|
||||
|
@ -369,6 +371,7 @@ async def acompletion(
|
|||
or custom_llm_provider == "bedrock"
|
||||
or custom_llm_provider == "databricks"
|
||||
or custom_llm_provider == "clarifai"
|
||||
or custom_llm_provider == "watsonx"
|
||||
or custom_llm_provider in litellm.openai_compatible_providers
|
||||
): # currently implemented aiohttp calls for just azure, openai, hf, ollama, vertex ai soon all.
|
||||
init_response = await loop.run_in_executor(None, func_with_context)
|
||||
|
@ -2352,7 +2355,7 @@ def completion(
|
|||
response = response
|
||||
elif custom_llm_provider == "watsonx":
|
||||
custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict
|
||||
response = watsonx.IBMWatsonXAI().completion(
|
||||
response = watsonxai.completion(
|
||||
model=model,
|
||||
messages=messages,
|
||||
custom_prompt_dict=custom_prompt_dict,
|
||||
|
@ -2364,6 +2367,7 @@ def completion(
|
|||
encoding=encoding,
|
||||
logging_obj=logging,
|
||||
timeout=timeout, # type: ignore
|
||||
acompletion=acompletion,
|
||||
)
|
||||
if (
|
||||
"stream" in optional_params
|
||||
|
@ -3030,6 +3034,7 @@ async def aembedding(*args, **kwargs) -> EmbeddingResponse:
|
|||
or custom_llm_provider == "ollama"
|
||||
or custom_llm_provider == "vertex_ai"
|
||||
or custom_llm_provider == "databricks"
|
||||
or custom_llm_provider == "watsonx"
|
||||
): # currently implemented aiohttp calls for just azure and openai, soon all.
|
||||
# Await normally
|
||||
init_response = await loop.run_in_executor(None, func_with_context)
|
||||
|
@ -3537,13 +3542,14 @@ def embedding(
|
|||
aembedding=aembedding,
|
||||
)
|
||||
elif custom_llm_provider == "watsonx":
|
||||
response = watsonx.IBMWatsonXAI().embedding(
|
||||
response = watsonxai.embedding(
|
||||
model=model,
|
||||
input=input,
|
||||
encoding=encoding,
|
||||
logging_obj=logging,
|
||||
optional_params=optional_params,
|
||||
model_response=EmbeddingResponse(),
|
||||
aembedding=aembedding,
|
||||
)
|
||||
else:
|
||||
args = locals()
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue