Added support for IBM watsonx.ai models

This commit is contained in:
Simon Sanchez Viloria 2024-04-20 19:56:20 +02:00
parent e52e4cc1a9
commit 6edb133733
5 changed files with 638 additions and 0 deletions

View file

@ -5331,6 +5331,45 @@ def get_optional_params(
optional_params["extra_body"] = (
extra_body # openai client supports `extra_body` param
)
elif custom_llm_provider == "watsonx":
supported_params = get_supported_openai_params(
model=model, custom_llm_provider=custom_llm_provider
)
_check_valid_arg(supported_params=supported_params)
if max_tokens is not None:
optional_params["max_new_tokens"] = max_tokens
if stream:
optional_params["stream"] = stream
if temperature is not None:
optional_params["temperature"] = temperature
if top_p is not None:
optional_params["top_p"] = top_p
if frequency_penalty is not None:
optional_params["repetition_penalty"] = frequency_penalty
if seed is not None:
optional_params["random_seed"] = seed
if stop is not None:
optional_params["stop_sequences"] = stop
# WatsonX-only parameters
extra_body = {}
if "decoding_method" in passed_params:
extra_body["decoding_method"] = passed_params.pop("decoding_method")
if "min_tokens" in passed_params or "min_new_tokens" in passed_params:
extra_body["min_new_tokens"] = passed_params.pop("min_tokens", passed_params.pop("min_new_tokens"))
if "top_k" in passed_params:
extra_body["top_k"] = passed_params.pop("top_k")
if "truncate_input_tokens" in passed_params:
extra_body["truncate_input_tokens"] = passed_params.pop("truncate_input_tokens")
if "length_penalty" in passed_params:
extra_body["length_penalty"] = passed_params.pop("length_penalty")
if "time_limit" in passed_params:
extra_body["time_limit"] = passed_params.pop("time_limit")
if "return_options" in passed_params:
extra_body["return_options"] = passed_params.pop("return_options")
optional_params["extra_body"] = (
extra_body # openai client supports `extra_body` param
)
else: # assume passing in params for openai/azure openai
print_verbose(
f"UNMAPPED PROVIDER, ASSUMING IT'S OPENAI/AZURE - model={model}, custom_llm_provider={custom_llm_provider}"
@ -5688,6 +5727,8 @@ def get_supported_openai_params(model: str, custom_llm_provider: str):
"frequency_penalty",
"presence_penalty",
]
elif custom_llm_provider == "watsonx":
return litellm.IBMWatsonXConfig().get_supported_openai_params()
def get_formatted_prompt(
@ -5914,6 +5955,8 @@ def get_llm_provider(
model in litellm.bedrock_models or model in litellm.bedrock_embedding_models
):
custom_llm_provider = "bedrock"
elif model in litellm.watsonx_models:
custom_llm_provider = "watsonx"
# openai embeddings
elif model in litellm.open_ai_embedding_models:
custom_llm_provider = "openai"
@ -9590,6 +9633,26 @@ class CustomStreamWrapper:
"is_finished": chunk["is_finished"],
"finish_reason": finish_reason,
}
def handle_watsonx_stream(self, chunk):
try:
if isinstance(chunk, dict):
pass
elif isinstance(chunk, str):
chunk = json.loads(chunk)
result = chunk.get("results", [])
if len(result) > 0:
text = result[0].get("generated_text", "")
finish_reason = result[0].get("stop_reason")
is_finished = finish_reason != 'not_finished'
return {
"text": text,
"is_finished": is_finished,
"finish_reason": finish_reason,
}
return ""
except Exception as e:
raise e
def model_response_creator(self):
model_response = ModelResponse(stream=True, model=self.model)
@ -9845,6 +9908,12 @@ class CustomStreamWrapper:
print_verbose(f"completion obj content: {completion_obj['content']}")
if response_obj["is_finished"]:
self.received_finish_reason = response_obj["finish_reason"]
elif self.custom_llm_provider == "watsonx":
response_obj = self.handle_watsonx_stream(chunk)
completion_obj["content"] = response_obj["text"]
print_verbose(f"completion obj content: {completion_obj['content']}")
if response_obj["is_finished"]:
self.received_finish_reason = response_obj["finish_reason"]
elif self.custom_llm_provider == "text-completion-openai":
response_obj = self.handle_openai_text_completion_chunk(chunk)
completion_obj["content"] = response_obj["text"]