forked from phoenix/litellm-mirror
Added support for IBM watsonx.ai models
This commit is contained in:
parent e52e4cc1a9
commit 6edb133733
5 changed files with 638 additions and 0 deletions
litellm/__init__.py

@@ -298,6 +298,7 @@ aleph_alpha_models: List = []
bedrock_models: List = []
deepinfra_models: List = []
perplexity_models: List = []
watsonx_models: List = []
for key, value in model_cost.items():
    if value.get("litellm_provider") == "openai":
        open_ai_chat_completion_models.append(key)

@@ -342,6 +343,8 @@ for key, value in model_cost.items():
        deepinfra_models.append(key)
    elif value.get("litellm_provider") == "perplexity":
        perplexity_models.append(key)
    elif value.get("litellm_provider") == "watsonx":
        watsonx_models.append(key)

# known openai compatible endpoints - we'll eventually move this list to the model_prices_and_context_window.json dictionary
openai_compatible_endpoints: List = [

@@ -478,6 +481,7 @@ model_list = (
    + perplexity_models
    + maritalk_models
    + vertex_language_models
    + watsonx_models
)

provider_list: List = [

@@ -516,6 +520,7 @@ provider_list: List = [
    "cloudflare",
    "xinference",
    "fireworks_ai",
    "watsonx",
    "custom",  # custom apis
]

@@ -537,6 +542,7 @@ models_by_provider: dict = {
    "deepinfra": deepinfra_models,
    "perplexity": perplexity_models,
    "maritalk": maritalk_models,
    "watsonx": watsonx_models,
}

# mapping for those models which have larger equivalents

@@ -650,6 +656,7 @@ from .llms.bedrock import (
)
from .llms.openai import OpenAIConfig, OpenAITextCompletionConfig
from .llms.azure import AzureOpenAIConfig, AzureOpenAIError
from .llms.watsonx import IBMWatsonXConfig
from .main import *  # type: ignore
from .integrations import *
from .exceptions import (
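With the watsonx model list, provider entry, and IBMWatsonXConfig export registered above, requests go through litellm's standard completion entrypoint. A minimal sketch, with placeholder credentials and model id; the `watsonx/` prefix assumes litellm's usual provider-prefix convention, while the env var names match those read by init_watsonx_model() below:

import os
import litellm

# placeholder credentials - the env var names are the ones read by init_watsonx_model()
os.environ["WX_API_KEY"] = "my-ibm-cloud-api-key"
os.environ["WX_URL"] = "https://us-south.ml.cloud.ibm.com"
os.environ["WX_PROJECT_ID"] = "my-project-id"

response = litellm.completion(
    model="watsonx/ibm/granite-13b-chat-v2",  # provider-prefixed model id (assumed convention)
    messages=[{"role": "user", "content": "What is IBM watsonx.ai?"}],
    max_tokens=200,
)
print(response.choices[0].message.content)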
litellm/llms/prompt_templates/factory.py

@@ -416,6 +416,32 @@ def format_prompt_togetherai(messages, prompt_format, chat_template):
        prompt = default_pt(messages)
    return prompt


### IBM Granite


def ibm_granite_pt(messages: list):
    """
    IBM's Granite models use the template:
    <|system|> {system_message} <|user|> {user_message} <|assistant|> {assistant_message}

    See: https://www.ibm.com/docs/en/watsonx-as-a-service?topic=solutions-supported-foundation-models
    """
    return custom_prompt(
        messages=messages,
        role_dict={
            'system': {
                'pre_message': '<|system|>\n',
                'post_message': '\n',
            },
            'user': {
                'pre_message': '<|user|>\n',
                'post_message': '\n',
            },
            'assistant': {
                'pre_message': '<|assistant|>\n',
                'post_message': '\n',
            }
        }
    ).strip()


### ANTHROPIC ###
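For reference, a short conversation passed through ibm_granite_pt above would render roughly as follows (a sketch based on the pre_message/post_message markers in the role_dict; exact whitespace depends on custom_prompt):

<|system|>
You are a helpful assistant.
<|user|>
What is watsonx.ai?
<|assistant|>
watsonx.ai is IBM's enterprise AI studio.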
@@ -1327,6 +1353,24 @@ def prompt_factory(
        return messages
    elif custom_llm_provider == "azure_text":
        return azure_text_pt(messages=messages)
    elif custom_llm_provider == "watsonx":
        if "granite" in model and "chat" in model:
            # granite-13b-chat-v1 and granite-13b-chat-v2 use a specific prompt template
            return ibm_granite_pt(messages=messages)
        elif "ibm-mistral" in model:
            # models like ibm-mistral/mixtral-8x7b-instruct-v01-q use the mistral instruct prompt template
            return mistral_instruct_pt(messages=messages)
        elif "meta-llama/llama-3" in model and "instruct" in model:
            return custom_prompt(
                role_dict={
                    "system": {"pre_message": "<|start_header_id|>system<|end_header_id|>\n", "post_message": "<|eot_id|>"},
                    "user": {"pre_message": "<|start_header_id|>user<|end_header_id|>\n", "post_message": "<|eot_id|>"},
                    "assistant": {"pre_message": "<|start_header_id|>assistant<|end_header_id|>\n", "post_message": "<|eot_id|>"},
                },
                messages=messages,
                initial_prompt_value="<|begin_of_text|>",
                # final_prompt_value="\n",
            )
    try:
        if "meta-llama/llama-2" in model and "chat" in model:
            return llama_2_chat_pt(messages=messages)
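A rough sketch of how the new watsonx branch routes different model ids (the ids below are illustrative; actual availability depends on the watsonx.ai catalog):

from litellm.llms.prompt_templates.factory import prompt_factory

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello!"},
]

# granite + chat -> ibm_granite_pt template
granite_prompt = prompt_factory(
    model="ibm/granite-13b-chat-v2", messages=messages, custom_llm_provider="watsonx"
)
# ibm-mistral/* -> mistral instruct template
mixtral_prompt = prompt_factory(
    model="ibm-mistral/mixtral-8x7b-instruct-v01-q", messages=messages, custom_llm_provider="watsonx"
)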
litellm/llms/watsonx.py (new file, 480 lines)
@@ -0,0 +1,480 @@
import json, types, time
from typing import Callable, Optional, Any, Union, List

import httpx
import litellm
from litellm.utils import ModelResponse, get_secret, Usage, ImageResponse

from .prompt_templates import factory as ptf


class WatsonxError(Exception):
    def __init__(self, status_code, message):
        self.status_code = status_code
        self.message = message
        self.request = httpx.Request(
            method="POST", url="https://us-south.ml.cloud.ibm.com"
        )
        self.response = httpx.Response(status_code=status_code, request=self.request)
        super().__init__(
            self.message
        )  # Call the base class constructor with the parameters it needs
class IBMWatsonXConfig:
    """
    Reference: https://cloud.ibm.com/apidocs/watsonx-ai#deployments-text-generation
    (See ibm_watsonx_ai.metanames.GenTextParamsMetaNames for a list of all available params)

    Supported params for all available watsonx.ai foundation models:

    - `decoding_method` (str): One of "greedy" or "sample".

    - `temperature` (float): Sets the model temperature for sampling - not available when decoding_method='greedy'.

    - `max_new_tokens` (integer): Maximum length of the generated tokens.

    - `min_new_tokens` (integer): Minimum number of tokens to generate.

    - `stop_sequences` (string[]): list of strings to use as stop sequences.

    - `time_limit` (integer): time limit in milliseconds. If the generation is not completed within the time limit, the model will return the generated text up to that point.

    - `top_p` (float): top p for sampling - not available when decoding_method='greedy'.

    - `top_k` (integer): top k for sampling - not available when decoding_method='greedy'.

    - `repetition_penalty` (float): token repetition penalty during text generation.

    - `stream` (bool): If True, the model will return a stream of responses.

    - `return_options` (dict): A dictionary of options to return. Options include "input_text", "generated_tokens", "input_tokens", "token_ranks".

    - `truncate_input_tokens` (integer): Truncate input tokens to this length; any more than this will be truncated.

    - `length_penalty` (dict): A dictionary with keys "decay_factor" and "start_index".

    - `random_seed` (integer): Random seed for text generation.

    - `guardrails` (bool): Enable guardrails for harmful content.

    - `guardrails_hap_params` (dict): Guardrails for harmful content.

    - `guardrails_pii_params` (dict): Guardrails for Personally Identifiable Information.

    - `concurrency_limit` (integer): Maximum number of concurrent requests.

    - `async_mode` (bool): Enable async mode.

    - `verify` (bool): Verify the SSL certificate of calls to the watsonx url.

    - `validate` (bool): Validate the model_id at initialization.

    - `model_inference` (ibm_watsonx_ai.ModelInference): An instance of an ibm_watsonx_ai.ModelInference class to use instead of creating a new model instance.

    - `watsonx_client` (ibm_watsonx_ai.APIClient): An instance of an ibm_watsonx_ai.APIClient class to initialize the watsonx model with.
    """

    decoding_method: Optional[str] = "sample"  # 'sample' or 'greedy'; "sample" follows the default openai API behavior
    temperature: Optional[float] = None
    min_new_tokens: Optional[int] = None
    max_new_tokens: Optional[int] = litellm.max_tokens
    top_k: Optional[int] = None
    top_p: Optional[float] = None
    random_seed: Optional[int] = None  # e.g. 42
    repetition_penalty: Optional[float] = None
    stop_sequences: Optional[List[str]] = None  # e.g. ["}", ")", "."]
    time_limit: Optional[int] = None  # e.g. 10000 (timeout in milliseconds)
    return_options: Optional[dict] = None  # e.g. {"input_text": True, "generated_tokens": True, "input_tokens": True, "token_ranks": False}
    truncate_input_tokens: Optional[int] = None  # e.g. 512
    length_penalty: Optional[dict] = None  # e.g. {"decay_factor": 2.5, "start_index": 5}
    stream: Optional[bool] = False
    # other inference params
    guardrails: Optional[bool] = False  # enable guardrails
    guardrails_hap_params: Optional[dict] = None  # guardrails for harmful content
    guardrails_pii_params: Optional[dict] = None  # guardrails for Personally Identifiable Information
    concurrency_limit: Optional[int] = 10  # max number of concurrent requests
    async_mode: Optional[bool] = False  # enable async mode
    verify: Optional[Union[bool, str]] = None  # verify the SSL certificate of calls to the watsonx url
    validate: Optional[bool] = False  # validate the model_id at initialization
    model_inference: Optional[object] = None  # an instance of an ibm_watsonx_ai.ModelInference class to use instead of creating a new model instance
    watsonx_client: Optional[object] = None  # an instance of an ibm_watsonx_ai.APIClient class to initialize the watsonx model with

    def __init__(
        self,
        decoding_method: Optional[str] = None,
        temperature: Optional[float] = None,
        min_new_tokens: Optional[int] = None,
        max_new_tokens: Optional[int] = litellm.max_tokens,  # defaults to litellm.max_tokens
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        random_seed: Optional[int] = None,
        repetition_penalty: Optional[float] = None,
        stop_sequences: Optional[List[str]] = None,
        time_limit: Optional[int] = None,
        return_options: Optional[dict] = None,
        truncate_input_tokens: Optional[int] = None,
        length_penalty: Optional[dict] = None,
        stream: Optional[bool] = False,
        guardrails: Optional[bool] = False,
        guardrails_hap_params: Optional[dict] = None,
        guardrails_pii_params: Optional[dict] = None,
        concurrency_limit: Optional[int] = 10,
        async_mode: Optional[bool] = False,
        verify: Optional[Union[bool, str]] = None,
        validate: Optional[bool] = False,
        model_inference: Optional[object] = None,
        watsonx_client: Optional[object] = None,
    ) -> None:
        locals_ = locals()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }

    def get_supported_openai_params(self):
        return [
            "temperature",  # equivalent to temperature
            "max_tokens",  # equivalent to max_new_tokens
            "top_p",  # equivalent to top_p
            "frequency_penalty",  # equivalent to repetition_penalty
            "stop",  # equivalent to stop_sequences
            "seed",  # equivalent to random_seed
            "stream",  # equivalent to stream
        ]
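A small sketch of how this config is meant to be used: IBMWatsonXConfig.__init__ stores any non-None arguments as class attributes, and completion() below merges get_config() into optional_params, so values set once act as provider-wide defaults.

import litellm

# set provider-level defaults once; every watsonx request that does not
# override these values picks them up via IBMWatsonXConfig.get_config()
litellm.IBMWatsonXConfig(decoding_method="greedy", max_new_tokens=512, repetition_penalty=1.1)

print(litellm.IBMWatsonXConfig.get_config())
# e.g. {'decoding_method': 'greedy', 'max_new_tokens': 512, 'repetition_penalty': 1.1, ...}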
def init_watsonx_model(
    model_id: str,
    url: Optional[str] = None,
    api_key: Optional[str] = None,
    project_id: Optional[str] = None,
    space_id: Optional[str] = None,
    wx_credentials: Optional[dict] = None,
    region_name: Optional[str] = None,
    verify: Optional[Union[bool, str]] = None,
    validate: Optional[bool] = False,
    watsonx_client: Optional[object] = None,
    model_params: Optional[dict] = None,
):
    """
    Initialize a watsonx.ai model for inference.

    Args:

        model_id (str): The model ID to use for inference. If this is a model deployed in a deployment space, the model_id should be in the format 'deployment/<deployment_id>' and the space_id of the deployment space should be provided.
        url (str): The URL of the watsonx.ai instance.
        api_key (str): The API key for the watsonx.ai instance.
        project_id (str): The project ID for the watsonx.ai instance.
        space_id (str): The space ID for the deployment space.
        wx_credentials (dict): A dictionary containing 'apikey' and 'url' keys for the watsonx.ai instance.
        region_name (str): The region name for the watsonx.ai instance (e.g. 'us-south').
        verify (bool): Whether to verify the SSL certificate of calls to the watsonx url.
        validate (bool): Whether to validate the model_id at initialization.
        watsonx_client (object): An instance of the ibm_watsonx_ai.APIClient class. If this is provided, the model will be initialized using the provided client.
        model_params (dict): A dictionary containing additional parameters to pass to the model (see IBMWatsonXConfig for a list of supported parameters).
    """

    from ibm_watsonx_ai import APIClient
    from ibm_watsonx_ai.foundation_models import ModelInference

    if wx_credentials is not None:
        if 'apikey' not in wx_credentials and 'api_key' in wx_credentials:
            wx_credentials['apikey'] = wx_credentials.pop('api_key')
        if 'apikey' not in wx_credentials:
            raise WatsonxError(500, "Error: key 'apikey' expected in wx_credentials")

    if url is None:
        url = get_secret("WX_URL") or get_secret("WATSONX_URL") or get_secret("WML_URL")
    if api_key is None:
        api_key = get_secret("WX_API_KEY") or get_secret("WML_API_KEY")
    if project_id is None:
        project_id = get_secret("WX_PROJECT_ID") or get_secret("PROJECT_ID")
    if region_name is None:
        region_name = get_secret("WML_REGION_NAME") or get_secret("WX_REGION_NAME") or get_secret("REGION_NAME")
    if space_id is None:
        space_id = get_secret("WX_SPACE_ID") or get_secret("WML_DEPLOYMENT_SPACE_ID") or get_secret("SPACE_ID")

    ## CHECK IF 'os.environ/' IS PASSED IN
    # Define the list of parameters to check
    params_to_check = [url, api_key, project_id, space_id, region_name]
    # Iterate over parameters and update if needed
    for i, param in enumerate(params_to_check):
        if param and param.startswith("os.environ/"):
            params_to_check[i] = get_secret(param)
    # Assign updated values back to parameters
    url, api_key, project_id, space_id, region_name = params_to_check

    ### SET WATSONX URL
    if url is not None or watsonx_client is not None or wx_credentials is not None:
        pass
    elif region_name is not None:
        url = f"https://{region_name}.ml.cloud.ibm.com"
    else:
        raise WatsonxError(
            message="Watsonx URL not set: set WX_URL env variable or in .env file",
            status_code=401,
        )
    if watsonx_client is not None and project_id is None:
        project_id = watsonx_client.project_id

    if model_id.startswith("deployment/"):
        # deployment models are passed in as 'deployment/<deployment_id>'
        assert space_id is not None, "space_id is required for deployment models"
        deployment_id = '/'.join(model_id.split("/")[1:])
        model_id = None
    else:
        deployment_id = None

    if watsonx_client is not None:
        model = ModelInference(
            model_id=model_id,
            params=model_params,
            api_client=watsonx_client,
            project_id=project_id,
            deployment_id=deployment_id,
            verify=verify,
            validate=validate,
            space_id=space_id,
        )
    elif wx_credentials is not None:
        model = ModelInference(
            model_id=model_id,
            params=model_params,
            credentials=wx_credentials,
            project_id=project_id,
            deployment_id=deployment_id,
            verify=verify,
            validate=validate,
            space_id=space_id,
        )
    elif api_key is not None:
        model = ModelInference(
            model_id=model_id,
            params=model_params,
            credentials={
                "apikey": api_key,
                "url": url,
            },
            project_id=project_id,
            deployment_id=deployment_id,
            verify=verify,
            validate=validate,
            space_id=space_id,
        )
    else:
        raise WatsonxError(500, "WatsonX credentials not passed or could not be found.")

    return model
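A minimal sketch of calling init_watsonx_model directly, assuming the ibm_watsonx_ai SDK is installed; the ids and key below are placeholders, and any argument left out falls back to the WX_*/WML_* environment variables read above:

from litellm.llms.watsonx import init_watsonx_model

model = init_watsonx_model(
    model_id="ibm/granite-13b-chat-v2",  # or "deployment/<deployment_id>" together with space_id
    region_name="us-south",              # expands to https://us-south.ml.cloud.ibm.com when no url is given
    api_key="my-ibm-cloud-api-key",      # placeholder; otherwise WX_API_KEY / WML_API_KEY is used
    project_id="my-project-id",          # placeholder; otherwise WX_PROJECT_ID / PROJECT_ID is used
    model_params={"decoding_method": "greedy", "max_new_tokens": 200},
)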
def convert_messages_to_prompt(model, messages, provider, custom_prompt_dict):
    # convert messages to a single prompt string, honoring any registered custom prompt template
    if model in custom_prompt_dict:
        # check if the model has a registered custom prompt
        model_prompt_dict = custom_prompt_dict[model]
        prompt = ptf.custom_prompt(
            messages=messages,
            role_dict=model_prompt_dict.get("role_dict", model_prompt_dict.get("roles")),
            initial_prompt_value=model_prompt_dict.get("initial_prompt_value", ""),
            final_prompt_value=model_prompt_dict.get("final_prompt_value", ""),
            bos_token=model_prompt_dict.get("bos_token", ""),
            eos_token=model_prompt_dict.get("eos_token", ""),
        )
        return prompt
    elif provider == "ibm":
        prompt = ptf.prompt_factory(
            model=model, messages=messages, custom_llm_provider="watsonx"
        )
    elif provider == "ibm-mistralai":
        prompt = ptf.mistral_instruct_pt(messages=messages)
    else:
        prompt = ptf.prompt_factory(model=model, messages=messages, custom_llm_provider='watsonx')
    return prompt


"""
IBM watsonx.ai AUTH Keys/Vars
os.environ['WX_URL'] = ""
os.environ['WX_API_KEY'] = ""
os.environ['WX_PROJECT_ID'] = ""
"""
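Because convert_messages_to_prompt checks litellm.custom_prompt_dict first, a per-model template can be registered up front. A sketch, using only the keys the function reads ("roles"/"role_dict", "initial_prompt_value", "final_prompt_value", "bos_token", "eos_token"); the model id and template values are illustrative:

import litellm

litellm.custom_prompt_dict["ibm/granite-13b-instruct-v2"] = {  # illustrative model id
    "roles": {
        "system": {"pre_message": "<|system|>\n", "post_message": "\n"},
        "user": {"pre_message": "<|user|>\n", "post_message": "\n"},
        "assistant": {"pre_message": "<|assistant|>\n", "post_message": "\n"},
    },
    "initial_prompt_value": "",
    "final_prompt_value": "<|assistant|>\n",
    "bos_token": "",
    "eos_token": "",
}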
def completion(
    model: str,
    messages: list,
    custom_prompt_dict: dict,
    model_response: ModelResponse,
    print_verbose: Callable,
    encoding,
    logging_obj,
    optional_params: Optional[dict] = None,
    litellm_params: Optional[dict] = None,
    logger_fn=None,
    timeout: Optional[float] = None,
):
    from ibm_watsonx_ai.foundation_models import Model, ModelInference

    try:
        stream = optional_params.pop("stream", False)
        extra_generate_params = dict(
            guardrails=optional_params.pop("guardrails", False),
            guardrails_hap_params=optional_params.pop("guardrails_hap_params", None),
            guardrails_pii_params=optional_params.pop("guardrails_pii_params", None),
            concurrency_limit=optional_params.pop("concurrency_limit", 10),
            async_mode=optional_params.pop("async_mode", False),
        )
        if timeout is not None and optional_params.get("time_limit") is None:
            # the time_limit in watsonx.ai is in milliseconds (as opposed to OpenAI which is in seconds)
            optional_params['time_limit'] = max(0, int(timeout * 1000))
        extra_body_params = optional_params.pop("extra_body", {})
        optional_params.update(extra_body_params)
        # LOAD CONFIG
        config = IBMWatsonXConfig.get_config()
        for k, v in config.items():
            if k not in optional_params:
                optional_params[k] = v

        model_inference = optional_params.pop("model_inference", None)
        if model_inference is None:
            # INIT MODEL
            model_client: ModelInference = init_watsonx_model(
                model_id=model,
                url=optional_params.pop("url", None),
                api_key=optional_params.pop("api_key", None),
                project_id=optional_params.pop("project_id", None),
                space_id=optional_params.pop("space_id", None),
                wx_credentials=optional_params.pop("wx_credentials", None),
                region_name=optional_params.pop("region_name", None),
                verify=optional_params.pop("verify", None),
                validate=optional_params.pop("validate", False),
                watsonx_client=optional_params.pop("watsonx_client", None),
                model_params=optional_params,
            )
        else:
            model_client: ModelInference = model_inference
            model = model_client.model_id

        # MAKE PROMPT
        provider = model.split("/")[0]
        model_name = '/'.join(model.split("/")[1:])
        prompt = convert_messages_to_prompt(
            model, messages, provider, custom_prompt_dict
        )
        ## COMPLETION CALL
        if stream is True:
            request_str = (
                "response = model.generate_text_stream(\n"
                f"\tprompt={prompt},\n"
                "\traw_response=True\n)"
            )
            logging_obj.pre_call(
                input=prompt,
                api_key="",
                additional_args={
                    "complete_input_dict": optional_params,
                    "request_str": request_str,
                },
            )
            # remove params that are not needed for streaming
            del extra_generate_params["async_mode"]
            del extra_generate_params["concurrency_limit"]
            # make generate call
            response = model_client.generate_text_stream(
                prompt=prompt,
                raw_response=True,
                **extra_generate_params
            )
            return litellm.CustomStreamWrapper(
                response,
                model=model,
                custom_llm_provider="watsonx",
                logging_obj=logging_obj,
            )
        else:
            try:
                ## LOGGING
                request_str = (
                    "response = model.generate(\n"
                    f"\tprompt={prompt},\n"
                    "\traw_response=True\n)"
                )
                logging_obj.pre_call(
                    input=prompt,
                    api_key="",
                    additional_args={
                        "complete_input_dict": optional_params,
                        "request_str": request_str,
                    },
                )
                response = model_client.generate(
                    prompt=prompt,
                    **extra_generate_params
                )
            except Exception as e:
                raise WatsonxError(status_code=500, message=str(e))

            ## LOGGING
            logging_obj.post_call(
                input=prompt,
                api_key="",
                original_response=json.dumps(response),
                additional_args={"complete_input_dict": optional_params},
            )
            print_verbose(f"raw model_response: {response}")
            ## BUILD RESPONSE OBJECT
            output_text = response['results'][0]['generated_text']

            try:
                if (
                    len(output_text) > 0
                    and hasattr(model_response.choices[0], "message")
                ):
                    model_response["choices"][0]["message"]["content"] = output_text
                    model_response["finish_reason"] = response['results'][0]['stop_reason']
                    prompt_tokens = response['results'][0]['input_token_count']
                    completion_tokens = response['results'][0]['generated_token_count']
                else:
                    raise Exception()
            except:
                raise WatsonxError(
                    message=json.dumps(output_text),
                    status_code=500,
                )
            model_response['created'] = int(time.time())
            model_response['model'] = model_name
            usage = Usage(
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
                total_tokens=prompt_tokens + completion_tokens,
            )
            model_response.usage = usage
            return model_response
    except WatsonxError as e:
        raise e
    except Exception as e:
        raise WatsonxError(status_code=500, message=str(e))


def embedding():
    # logic for parsing in - calling - parsing out model embedding calls
    pass
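End to end, streaming from a watsonx model would look roughly like this. A sketch: the model id is illustrative, the `watsonx/` prefix assumes litellm's provider-prefix convention, and the watsonx-only parameters travel through the extra_body mapping added in utils.py below:

import litellm

response = litellm.completion(
    model="watsonx/ibm/granite-13b-chat-v2",
    messages=[{"role": "user", "content": "Write a haiku about mainframes."}],
    stream=True,
    max_tokens=100,            # mapped to max_new_tokens
    decoding_method="sample",  # watsonx-only param, forwarded via extra_body
    top_k=50,                  # watsonx-only param, forwarded via extra_body
)
for chunk in response:
    print(chunk.choices[0].delta.content or "", end="")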
litellm/main.py

@@ -63,6 +63,7 @@ from .llms import (
    vertex_ai,
    vertex_ai_anthropic,
    maritalk,
    watsonx,
)
from .llms.openai import OpenAIChatCompletion, OpenAITextCompletion
from .llms.azure import AzureChatCompletion

@@ -1858,6 +1859,43 @@ def completion(

            ## RESPONSE OBJECT
            response = response
        elif custom_llm_provider == "watsonx":
            custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict
            response = watsonx.completion(
                model=model,
                messages=messages,
                custom_prompt_dict=custom_prompt_dict,
                model_response=model_response,
                print_verbose=print_verbose,
                optional_params=optional_params,
                litellm_params=litellm_params,
                logger_fn=logger_fn,
                encoding=encoding,
                logging_obj=logging,
                timeout=timeout,
            )
            if (
                "stream" in optional_params
                and optional_params["stream"] == True
                and not isinstance(response, CustomStreamWrapper)
            ):
                # don't try to access the stream object directly; wrap it
                response = CustomStreamWrapper(
                    iter(response),
                    model,
                    custom_llm_provider="watsonx",
                    logging_obj=logging,
                )

            if optional_params.get("stream", False):
                ## LOGGING
                logging.post_call(
                    input=messages,
                    api_key=None,
                    original_response=response,
                )
            ## RESPONSE OBJECT
            response = response
        elif custom_llm_provider == "vllm":
            custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict
            model_response = vllm.completion(
litellm/utils.py

@@ -5331,6 +5331,45 @@ def get_optional_params(
        optional_params["extra_body"] = (
            extra_body  # openai client supports `extra_body` param
        )
    elif custom_llm_provider == "watsonx":
        supported_params = get_supported_openai_params(
            model=model, custom_llm_provider=custom_llm_provider
        )
        _check_valid_arg(supported_params=supported_params)
        if max_tokens is not None:
            optional_params["max_new_tokens"] = max_tokens
        if stream:
            optional_params["stream"] = stream
        if temperature is not None:
            optional_params["temperature"] = temperature
        if top_p is not None:
            optional_params["top_p"] = top_p
        if frequency_penalty is not None:
            optional_params["repetition_penalty"] = frequency_penalty
        if seed is not None:
            optional_params["random_seed"] = seed
        if stop is not None:
            optional_params["stop_sequences"] = stop

        # WatsonX-only parameters
        extra_body = {}
        if "decoding_method" in passed_params:
            extra_body["decoding_method"] = passed_params.pop("decoding_method")
        if "min_tokens" in passed_params or "min_new_tokens" in passed_params:
            # pop whichever alias was provided ("min_tokens" wins if both are present)
            extra_body["min_new_tokens"] = passed_params.pop("min_tokens", passed_params.pop("min_new_tokens", None))
        if "top_k" in passed_params:
            extra_body["top_k"] = passed_params.pop("top_k")
        if "truncate_input_tokens" in passed_params:
            extra_body["truncate_input_tokens"] = passed_params.pop("truncate_input_tokens")
        if "length_penalty" in passed_params:
            extra_body["length_penalty"] = passed_params.pop("length_penalty")
        if "time_limit" in passed_params:
            extra_body["time_limit"] = passed_params.pop("time_limit")
        if "return_options" in passed_params:
            extra_body["return_options"] = passed_params.pop("return_options")
        optional_params["extra_body"] = (
            extra_body  # openai client supports `extra_body` param
        )
    else:  # assume passing in params for openai/azure openai
        print_verbose(
            f"UNMAPPED PROVIDER, ASSUMING IT'S OPENAI/AZURE - model={model}, custom_llm_provider={custom_llm_provider}"
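In short, the branch above translates OpenAI-style arguments into watsonx.ai generation parameters and tunnels the watsonx-only ones through extra_body. A rough before/after sketch with illustrative values:

openai_style_kwargs = {
    "max_tokens": 200, "temperature": 0.7, "frequency_penalty": 1.1,
    "seed": 42, "stop": ["\n\n"], "top_k": 50, "decoding_method": "sample",
}

watsonx_optional_params = {
    "max_new_tokens": 200,       # from max_tokens
    "temperature": 0.7,
    "repetition_penalty": 1.1,   # from frequency_penalty
    "random_seed": 42,           # from seed
    "stop_sequences": ["\n\n"],  # from stop
    "extra_body": {              # watsonx-only params
        "top_k": 50,
        "decoding_method": "sample",
    },
}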
@@ -5688,6 +5727,8 @@ def get_supported_openai_params(model: str, custom_llm_provider: str):
            "frequency_penalty",
            "presence_penalty",
        ]
    elif custom_llm_provider == "watsonx":
        return litellm.IBMWatsonXConfig().get_supported_openai_params()


def get_formatted_prompt(

@@ -5914,6 +5955,8 @@ def get_llm_provider(
            model in litellm.bedrock_models or model in litellm.bedrock_embedding_models
        ):
            custom_llm_provider = "bedrock"
        elif model in litellm.watsonx_models:
            custom_llm_provider = "watsonx"
        # openai embeddings
        elif model in litellm.open_ai_embedding_models:
            custom_llm_provider = "openai"
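With the branch above, a bare watsonx model id resolves to the watsonx provider automatically. A sketch, assuming the id is registered under litellm_provider "watsonx" in the cost map and that get_llm_provider keeps its usual 4-tuple return:

from litellm.utils import get_llm_provider

model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(
    model="ibm/granite-13b-chat-v2"  # illustrative id; must appear in litellm.watsonx_models
)
# custom_llm_provider == "watsonx"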
@@ -9590,6 +9633,26 @@ class CustomStreamWrapper:
                "is_finished": chunk["is_finished"],
                "finish_reason": finish_reason,
            }

    def handle_watsonx_stream(self, chunk):
        try:
            if isinstance(chunk, dict):
                pass
            elif isinstance(chunk, str):
                chunk = json.loads(chunk)
            result = chunk.get("results", [])
            if len(result) > 0:
                text = result[0].get("generated_text", "")
                finish_reason = result[0].get("stop_reason")
                is_finished = finish_reason != 'not_finished'
                return {
                    "text": text,
                    "is_finished": is_finished,
                    "finish_reason": finish_reason,
                }
            # no results in this chunk - return an empty, unfinished delta
            return {"text": "", "is_finished": False, "finish_reason": ""}
        except Exception as e:
            raise e

    def model_response_creator(self):
        model_response = ModelResponse(stream=True, model=self.model)
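For reference, a raw watsonx.ai stream event and the dict handle_watsonx_stream extracts from it (values are illustrative; only generated_text and stop_reason are read here):

chunk = {"results": [{"generated_text": "Hello", "stop_reason": "not_finished"}]}
# handle_watsonx_stream(chunk) ->
# {"text": "Hello", "is_finished": False, "finish_reason": "not_finished"}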
@@ -9845,6 +9908,12 @@ class CustomStreamWrapper:
                print_verbose(f"completion obj content: {completion_obj['content']}")
                if response_obj["is_finished"]:
                    self.received_finish_reason = response_obj["finish_reason"]
            elif self.custom_llm_provider == "watsonx":
                response_obj = self.handle_watsonx_stream(chunk)
                completion_obj["content"] = response_obj["text"]
                print_verbose(f"completion obj content: {completion_obj['content']}")
                if response_obj["is_finished"]:
                    self.received_finish_reason = response_obj["finish_reason"]
            elif self.custom_llm_provider == "text-completion-openai":
                response_obj = self.handle_openai_text_completion_chunk(chunk)
                completion_obj["content"] = response_obj["text"]