(feat) support async streaming for the watsonx provider

Simon Sanchez Viloria 2024-05-06 17:07:21 +02:00
parent 62b3f25398
commit 83a274b54b
3 changed files with 221 additions and 92 deletions

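With this change, watsonx chat completions can be streamed through litellm's async entry point. A minimal usage sketch of the new path (the model ID below is illustrative, not taken from this commit):

import asyncio
import litellm

async def main():
    # stream=True together with acompletion exercises the new async watsonx path
    response = await litellm.acompletion(
        model="watsonx/ibm/granite-13b-chat-v2",  # illustrative model ID
        messages=[{"role": "user", "content": "Hello!"}],
        stream=True,
    )
    async for chunk in response:
        print(chunk.choices[0].delta.content or "", end="")

asyncio.run(main())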

@@ -70,6 +70,7 @@ from .llms.azure_text import AzureTextCompletion
 from .llms.anthropic import AnthropicChatCompletion
 from .llms.anthropic_text import AnthropicTextCompletion
 from .llms.huggingface_restapi import Huggingface
+from .llms.watsonx import IBMWatsonXAI
 from .llms.prompt_templates.factory import (
     prompt_factory,
     custom_prompt,
@@ -105,6 +106,7 @@ anthropic_text_completions = AnthropicTextCompletion()
 azure_chat_completions = AzureChatCompletion()
 azure_text_completions = AzureTextCompletion()
 huggingface = Huggingface()
+watsonxai = IBMWatsonXAI()
 
 
 ####### COMPLETION ENDPOINTS ################
@@ -308,6 +310,7 @@ async def acompletion(
             or custom_llm_provider == "gemini"
             or custom_llm_provider == "sagemaker"
             or custom_llm_provider == "anthropic"
+            or custom_llm_provider == "watsonx"
             or custom_llm_provider in litellm.openai_compatible_providers
         ):  # currently implemented aiohttp calls for just azure, openai, hf, ollama, vertex ai soon all.
             init_response = await loop.run_in_executor(None, func_with_context)
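Routing "watsonx" into this branch means acompletion runs the synchronous completion() wrapper in an executor; because acompletion=True is forwarded (see the hunks below), the provider hands back a coroutine that litellm then awaits. A simplified sketch of that dispatch pattern (the function name here is a hypothetical reduction, not litellm's actual code):

import asyncio

async def dispatch(func_with_context):
    loop = asyncio.get_running_loop()
    # run the sync entry point off-thread; async providers return a coroutine from it
    init_response = await loop.run_in_executor(None, func_with_context)
    if asyncio.iscoroutine(init_response):
        # the provider saw acompletion=True and returned an awaitable
        return await init_response
    return init_response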
@@ -1865,7 +1868,7 @@ def completion(
             response = response
         elif custom_llm_provider == "watsonx":
             custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict
-            response = watsonx.IBMWatsonXAI().completion(
+            response = watsonxai.completion(
                 model=model,
                 messages=messages,
                 custom_prompt_dict=custom_prompt_dict,
@@ -1876,6 +1879,7 @@ def completion(
                 logger_fn=logger_fn,
                 encoding=encoding,
                 logging_obj=logging,
+                acompletion=acompletion,
                 timeout=timeout,
             )
             if (
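Forwarding acompletion lets IBMWatsonXAI.completion decide at call time whether to return a regular response, a coroutine, or an async generator when streaming. A minimal sketch of that flag-based shape, with illustrative names and payloads rather than the real watsonx handler:

class ProviderSketch:
    """Illustrative stand-in for a provider handler such as IBMWatsonXAI."""

    def completion(self, messages, stream=False, acompletion=False):
        if acompletion:
            # async path: hand back an async generator when streaming, else a coroutine
            return self._astream(messages) if stream else self._acompletion(messages)
        return {"choices": [{"message": {"content": "sync result"}}]}

    async def _acompletion(self, messages):
        return {"choices": [{"message": {"content": "async result"}}]}

    async def _astream(self, messages):
        for token in ("async", " ", "stream"):
            yield {"choices": [{"delta": {"content": token}}]}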
@@ -2528,6 +2532,7 @@ async def aembedding(*args, **kwargs):
             or custom_llm_provider == "fireworks_ai"
             or custom_llm_provider == "ollama"
             or custom_llm_provider == "vertex_ai"
+            or custom_llm_provider == "watsonx"
         ):  # currently implemented aiohttp calls for just azure and openai, soon all.
             # Await normally
             init_response = await loop.run_in_executor(None, func_with_context)
@@ -2980,13 +2985,14 @@ def embedding(
                 aembedding=aembedding,
             )
         elif custom_llm_provider == "watsonx":
-            response = watsonx.IBMWatsonXAI().embedding(
+            response = watsonxai.embedding(
                 model=model,
                 input=input,
                 encoding=encoding,
                 logging_obj=logging,
                 optional_params=optional_params,
                 model_response=EmbeddingResponse(),
+                aembedding=aembedding,
             )
         else:
             args = locals()
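The same flag now flows through embedding, so the async embedding entry point works for watsonx as well. A hedged usage sketch (the model ID is illustrative, not taken from this commit):

import asyncio
import litellm

async def main():
    response = await litellm.aembedding(
        model="watsonx/ibm/slate-30m-english-rtrvr",  # illustrative model ID
        input=["hello", "world"],
    )
    print(len(response.data))  # one embedding per input string

asyncio.run(main())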