forked from phoenix/litellm-mirror
feat - watsonx refactoring, removed dependency, and added support for embedding calls
This commit is contained in:
parent
a77537ddd4
commit
74d2ba0a23
4 changed files with 477 additions and 366 deletions
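
For orientation, a minimal usage sketch of the refactored provider (illustrative only; the model id and credential values are placeholders, and the environment variable names are taken from the new code below):

    # sketch: calling the rewritten watsonx handler through litellm (no ibm_watsonx_ai SDK required)
    import os
    import litellm

    os.environ["WATSONX_URL"] = "https://us-south.ml.cloud.ibm.com"   # assumed region url
    os.environ["WATSONX_API_KEY"] = "<ibm-cloud-api-key>"             # placeholder
    os.environ["WATSONX_PROJECT_ID"] = "<project-id>"                 # placeholder

    response = litellm.completion(
        model="watsonx/ibm/granite-13b-chat-v2",  # assumed model id with the watsonx/ provider prefix
        messages=[{"role": "user", "content": "Hello"}],
    )
    print(response.choices[0].message.content)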
@@ -656,7 +656,7 @@ from .llms.bedrock import (
 )
 from .llms.openai import OpenAIConfig, OpenAITextCompletionConfig
 from .llms.azure import AzureOpenAIConfig, AzureOpenAIError
-from .llms.watsonx import IBMWatsonXConfig
+from .llms.watsonx import IBMWatsonXAIConfig
 from .main import *  # type: ignore
 from .integrations import *
 from .exceptions import (
@@ -1,27 +1,31 @@
-import json, types, time
-from typing import Callable, Optional, Any, Union, List
+import json, types, time  # noqa: E401
+from contextlib import contextmanager
+from typing import Callable, Dict, Optional, Any, Union, List
 
 import httpx
+import requests
 import litellm
-from litellm.utils import ModelResponse, get_secret, Usage, ImageResponse
+from litellm.utils import ModelResponse, get_secret, Usage
 
+from .base import BaseLLM
 from .prompt_templates import factory as ptf
 
 
-class WatsonxError(Exception):
-    def __init__(self, status_code, message):
+class WatsonXAIError(Exception):
+    def __init__(self, status_code, message, url: str = None):
         self.status_code = status_code
         self.message = message
-        self.request = httpx.Request(
-            method="POST", url="https://https://us-south.ml.cloud.ibm.com"
-        )
+        url = url or "https://https://us-south.ml.cloud.ibm.com"
+        self.request = httpx.Request(method="POST", url=url)
         self.response = httpx.Response(status_code=status_code, request=self.request)
         super().__init__(
             self.message
         )  # Call the base class constructor with the parameters it needs
 
 
-class IBMWatsonXConfig:
+class IBMWatsonXAIConfig:
     """
-    Reference: https://cloud.ibm.com/apidocs/watsonx-ai#deployments-text-generation
+    Reference: https://cloud.ibm.com/apidocs/watsonx-ai#text-generation
     (See ibm_watsonx_ai.metanames.GenTextParamsMetaNames for a list of all available params)
 
     Supported params for all available watsonx.ai foundational models.
@@ -34,96 +38,64 @@ class IBMWatsonXConfig:
 
     - `min_new_tokens` (integer): Maximum length of input tokens. Any more than this will be truncated.
 
+    - `length_penalty` (dict): A dictionary with keys "decay_factor" and "start_index".
+
     - `stop_sequences` (string[]): list of strings to use as stop sequences.
 
-    - `time_limit` (integer): time limit in milliseconds. If the generation is not completed within the time limit, the model will return the generated text up to that point.
-
-    - `top_p` (integer): top p for sampling - not available when decoding_method='greedy'.
-
     - `top_k` (integer): top k for sampling - not available when decoding_method='greedy'.
 
+    - `top_p` (integer): top p for sampling - not available when decoding_method='greedy'.
+
     - `repetition_penalty` (float): token repetition penalty during text generation.
 
-    - `stream` (bool): If True, the model will return a stream of responses.
-
-    - `return_options` (dict): A dictionary of options to return. Options include "input_text", "generated_tokens", "input_tokens", "token_ranks".
-
     - `truncate_input_tokens` (integer): Truncate input tokens to this length.
 
-    - `length_penalty` (dict): A dictionary with keys "decay_factor" and "start_index".
+    - `include_stop_sequences` (bool): If True, the stop sequence will be included at the end of the generated text in the case of a match.
+
+    - `return_options` (dict): A dictionary of options to return. Options include "input_text", "generated_tokens", "input_tokens", "token_ranks". Values are boolean.
 
     - `random_seed` (integer): Random seed for text generation.
 
-    - `guardrails` (bool): Enable guardrails for harmful content.
+    - `moderations` (dict): Dictionary of properties that control the moderations, for usages such as Hate and profanity (HAP) and PII filtering.
 
-    - `guardrails_hap_params` (dict): Guardrails for harmful content.
+    - `stream` (bool): If True, the model will return a stream of responses.
 
-    - `guardrails_pii_params` (dict): Guardrails for Personally Identifiable Information.
-
-    - `concurrency_limit` (integer): Maximum number of concurrent requests.
-
-    - `async_mode` (bool): Enable async mode.
-
-    - `verify` (bool): Verify the SSL certificate of calls to the watsonx url.
-
-    - `validate` (bool): Validate the model_id at initialization.
-
-    - `model_inference` (ibm_watsonx_ai.ModelInference): An instance of an ibm_watsonx_ai.ModelInference class to use instead of creating a new model instance.
-
-    - `watsonx_client` (ibm_watsonx_ai.APIClient): An instance of an ibm_watsonx_ai.APIClient class to initialize the watsonx model with.
-
     """
-    decoding_method: Optional[str] = "sample"  # 'sample' or 'greedy'. "sample" follows the default openai API behavior
-    temperature: Optional[float] = None  #
+
+    decoding_method: Optional[str] = "sample"
+    temperature: Optional[float] = None
+    max_new_tokens: Optional[int] = None  # litellm.max_tokens
     min_new_tokens: Optional[int] = None
-    max_new_tokens: Optional[int] = litellm.max_tokens
+    length_penalty: Optional[dict] = None  # e.g {"decay_factor": 2.5, "start_index": 5}
+    stop_sequences: Optional[List[str]] = None  # e.g ["}", ")", "."]
     top_k: Optional[int] = None
     top_p: Optional[float] = None
-    random_seed: Optional[int] = None  # e.g 42
     repetition_penalty: Optional[float] = None
-    stop_sequences: Optional[List[str]] = None  # e.g ["}", ")", "."]
-    time_limit: Optional[int] = None  # e.g 10000 (timeout in milliseconds)
-    return_options: Optional[dict] = None  # e.g {"input_text": True, "generated_tokens": True, "input_tokens": True, "token_ranks": False}
-    truncate_input_tokens: Optional[int] = None  # e.g 512
-    length_penalty: Optional[dict] = None  # e.g {"decay_factor": 2.5, "start_index": 5}
+    truncate_input_tokens: Optional[int] = None
+    include_stop_sequences: Optional[bool] = False
+    return_options: Optional[dict] = None
+    return_options: Optional[Dict[str, bool]] = None
+    random_seed: Optional[int] = None  # e.g 42
+    moderations: Optional[dict] = None
     stream: Optional[bool] = False
-    # other inference params
-    guardrails: Optional[bool] = False  # enable guardrails
-    guardrails_hap_params: Optional[dict] = None  # guardrails for harmful content
-    guardrails_pii_params: Optional[dict] = None  # guardrails for Personally Identifiable Information
-    concurrency_limit: Optional[int] = 10  # max number of concurrent requests
-    async_mode: Optional[bool] = False  # enable async mode
-    verify: Optional[Union[bool,str]] = None  # verify the SSL certificate of calls to the watsonx url
-    validate: Optional[bool] = False  # validate the model_id at initialization
-    model_inference: Optional[object] = None  # an instance of an ibm_watsonx_ai.ModelInference class to use instead of creating a new model instance
-    watsonx_client: Optional[object] = None  # an instance of an ibm_watsonx_ai.APIClient class to initialize the watsonx model with
 
     def __init__(
         self,
         decoding_method: Optional[str] = None,
         temperature: Optional[float] = None,
+        max_new_tokens: Optional[int] = None,
         min_new_tokens: Optional[int] = None,
-        max_new_tokens: Optional[
-            int
-        ] = litellm.max_tokens,  # petals requires max tokens to be set
+        length_penalty: Optional[dict] = None,
+        stop_sequences: Optional[List[str]] = None,
         top_k: Optional[int] = None,
         top_p: Optional[float] = None,
-        random_seed: Optional[int] = None,
         repetition_penalty: Optional[float] = None,
-        stop_sequences: Optional[List[str]] = None,
-        time_limit: Optional[int] = None,
-        return_options: Optional[dict] = None,
         truncate_input_tokens: Optional[int] = None,
-        length_penalty: Optional[dict] = None,
-        stream: Optional[bool] = False,
-        guardrails: Optional[bool] = False,
-        guardrails_hap_params: Optional[dict] = None,
-        guardrails_pii_params: Optional[dict] = None,
-        concurrency_limit: Optional[int] = 10,
-        async_mode: Optional[bool] = False,
-        verify: Optional[Union[bool,str]] = None,
-        validate: Optional[bool] = False,
-        model_inference: Optional[object] = None,
-        watsonx_client: Optional[object] = None,
+        include_stop_sequences: Optional[bool] = None,
+        return_options: Optional[dict] = None,
+        random_seed: Optional[int] = None,
+        moderations: Optional[dict] = None,
+        stream: Optional[bool] = None,
+        **kwargs,
     ) -> None:
         locals_ = locals()
         for key, value in locals_.items():
@@ -160,133 +132,6 @@ class IBMWatsonXConfig:
         ]
 
 
-def init_watsonx_model(
-    model_id: str,
-    url: Optional[str] = None,
-    api_key: Optional[str] = None,
-    project_id: Optional[str] = None,
-    space_id: Optional[str] = None,
-    wx_credentials: Optional[dict] = None,
-    region_name: Optional[str] = None,
-    verify: Optional[Union[bool,str]] = None,
-    validate: Optional[bool] = False,
-    watsonx_client: Optional[object] = None,
-    model_params: Optional[dict] = None,
-):
-    """
-    Initialize a watsonx.ai model for inference.
-
-    Args:
-        model_id (str): The model ID to use for inference. If this is a model deployed in a deployment space, the model_id should be in the format 'deployment/<deployment_id>' and the space_id to the deploymend space should be provided.
-        url (str): The URL of the watsonx.ai instance.
-        api_key (str): The API key for the watsonx.ai instance.
-        project_id (str): The project ID for the watsonx.ai instance.
-        space_id (str): The space ID for the deployment space.
-        wx_credentials (dict): A dictionary containing 'apikey' and 'url' keys for the watsonx.ai instance.
-        region_name (str): The region name for the watsonx.ai instance (e.g. 'us-south').
-        verify (bool): Whether to verify the SSL certificate of calls to the watsonx url.
-        validate (bool): Whether to validate the model_id at initialization.
-        watsonx_client (object): An instance of the ibm_watsonx_ai.APIClient class. If this is provided, the model will be initialized using the provided client.
-        model_params (dict): A dictionary containing additional parameters to pass to the model (see IBMWatsonXConfig for a list of supported parameters).
-    """
-
-    from ibm_watsonx_ai import APIClient
-    from ibm_watsonx_ai.foundation_models import ModelInference
-
-    if wx_credentials is not None:
-        if 'apikey' not in wx_credentials and 'api_key' in wx_credentials:
-            wx_credentials['apikey'] = wx_credentials.pop('api_key')
-        if 'apikey' not in wx_credentials:
-            raise WatsonxError(500, "Error: key 'apikey' expected in wx_credentials")
-
-    if url is None:
-        url = get_secret("WX_URL") or get_secret("WATSONX_URL") or get_secret("WML_URL")
-    if api_key is None:
-        api_key = get_secret("WX_API_KEY") or get_secret("WML_API_KEY")
-    if project_id is None:
-        project_id = get_secret("WX_PROJECT_ID") or get_secret("PROJECT_ID")
-    if region_name is None:
-        region_name = get_secret("WML_REGION_NAME") or get_secret("WX_REGION_NAME") or get_secret("REGION_NAME")
-    if space_id is None:
-        space_id = get_secret("WX_SPACE_ID") or get_secret("WML_DEPLOYMENT_SPACE_ID") or get_secret("SPACE_ID")
-
-    ## CHECK IS 'os.environ/' passed in
-    # Define the list of parameters to check
-    params_to_check = (url, api_key, project_id, space_id, region_name)
-    # Iterate over parameters and update if needed
-    for i, param in enumerate(params_to_check):
-        if param and param.startswith("os.environ/"):
-            params_to_check[i] = get_secret(param)
-    # Assign updated values back to parameters
-    url, api_key, project_id, space_id, region_name = params_to_check
-
-    ### SET WATSONX URL
-    if url is not None or watsonx_client is not None or wx_credentials is not None:
-        pass
-    elif region_name is not None:
-        url = f"https://{region_name}.ml.cloud.ibm.com"
-    else:
-        raise WatsonxError(
-            message="Watsonx URL not set: set WX_URL env variable or in .env file",
-            status_code=401,
-        )
-    if watsonx_client is not None and project_id is None:
-        project_id = watsonx_client.project_id
-
-    if model_id.startswith("deployment/"):
-        # deployment models are passed in as 'deployment/<deployment_id>'
-        assert space_id is not None, "space_id is required for deployment models"
-        deployment_id = '/'.join(model_id.split("/")[1:])
-        model_id = None
-    else:
-        deployment_id = None
-
-    if watsonx_client is not None:
-        model = ModelInference(
-            model_id=model_id,
-            params=model_params,
-            api_client=watsonx_client,
-            project_id=project_id,
-            deployment_id=deployment_id,
-            verify=verify,
-            validate=validate,
-            space_id=space_id,
-        )
-    elif wx_credentials is not None:
-        model = ModelInference(
-            model_id=model_id,
-            params=model_params,
-            credentials=wx_credentials,
-            project_id=project_id,
-            deployment_id=deployment_id,
-            verify=verify,
-            validate=validate,
-            space_id=space_id,
-        )
-    elif api_key is not None:
-        model = ModelInference(
-            model_id=model_id,
-            params=model_params,
-            credentials={
-                "apikey": api_key,
-                "url": url,
-            },
-            project_id=project_id,
-            deployment_id=deployment_id,
-            verify=verify,
-            validate=validate,
-            space_id=space_id,
-        )
-    else:
-        raise WatsonxError(500, "WatsonX credentials not passed or could not be found.")
-
-    return model
-
-
 def convert_messages_to_prompt(model, messages, provider, custom_prompt_dict):
     # handle anthropic prompts and amazon titan prompts
     if model in custom_prompt_dict:
@@ -294,8 +139,10 @@ def convert_messages_to_prompt(model, messages, provider, custom_prompt_dict):
         model_prompt_dict = custom_prompt_dict[model]
         prompt = ptf.custom_prompt(
             messages=messages,
-            role_dict=model_prompt_dict.get("role_dict", model_prompt_dict.get("roles")),
-            initial_prompt_value=model_prompt_dict.get("initial_prompt_value",""),
+            role_dict=model_prompt_dict.get(
+                "role_dict", model_prompt_dict.get("roles")
+            ),
+            initial_prompt_value=model_prompt_dict.get("initial_prompt_value", ""),
             final_prompt_value=model_prompt_dict.get("final_prompt_value", ""),
             bos_token=model_prompt_dict.get("bos_token", ""),
             eos_token=model_prompt_dict.get("eos_token", ""),
@@ -308,18 +155,176 @@ def convert_messages_to_prompt(model, messages, provider, custom_prompt_dict):
     elif provider == "ibm-mistralai":
         prompt = ptf.mistral_instruct_pt(messages=messages)
     else:
-        prompt = ptf.prompt_factory(model=model, messages=messages, custom_llm_provider='watsonx')
+        prompt = ptf.prompt_factory(
+            model=model, messages=messages, custom_llm_provider="watsonx"
+        )
     return prompt
 
 
-"""
-IBM watsonx.ai AUTH Keys/Vars
-os.environ['WX_URL'] = ""
-os.environ['WX_API_KEY'] = ""
-os.environ['WX_PROJECT_ID'] = ""
-"""
-
-
-def completion(
+class IBMWatsonXAI(BaseLLM):
+    """
+    Class to interface with IBM Watsonx.ai API for text generation and embeddings.
+
+    Reference: https://cloud.ibm.com/apidocs/watsonx-ai
+    """
+
+    api_version = "2024-03-13"
+    _text_gen_endpoint = "/ml/v1/text/generation"
+    _text_gen_stream_endpoint = "/ml/v1/text/generation_stream"
+    _deployment_text_gen_endpoint = "/ml/v1/deployments/{deployment_id}/text/generation"
+    _deployment_text_gen_stream_endpoint = (
+        "/ml/v1/deployments/{deployment_id}/text/generation_stream"
+    )
+    _embeddings_endpoint = "/ml/v1/text/embeddings"
+    _prompts_endpoint = "/ml/v1/prompts"
+
+    def __init__(self) -> None:
+        super().__init__()
+
+    def _prepare_text_generation_req(
+        self,
+        model_id: str,
+        prompt: str,
+        stream: bool,
+        optional_params: dict,
+        print_verbose: Callable = None,
+    ) -> httpx.Request:
+        """
+        Get the request parameters for text generation.
+        """
+        api_params = self._get_api_params(optional_params, print_verbose=print_verbose)
+        # build auth headers
+        api_token = api_params.get("token")
+
+        headers = {
+            "Authorization": f"Bearer {api_token}",
+            "Content-Type": "application/json",
+            "Accept": "application/json",
+        }
+        extra_body_params = optional_params.pop("extra_body", {})
+        optional_params.update(extra_body_params)
+        # init the payload to the text generation call
+        payload = {
+            "input": prompt,
+            "moderations": optional_params.pop("moderations", {}),
+            "parameters": optional_params,
+        }
+        request_params = dict(version=api_params["api_version"])
+        # text generation endpoint deployment or model / stream or not
+        if model_id.startswith("deployment/"):
+            # deployment models are passed in as 'deployment/<deployment_id>'
+            if api_params.get("space_id") is None:
+                raise WatsonXAIError(
+                    status_code=401,
+                    url=api_params["url"],
+                    message="Error: space_id is required for models called using the 'deployment/' endpoint. Pass in the space_id as a parameter or set it in the WX_SPACE_ID environment variable.",
+                )
+            deployment_id = "/".join(model_id.split("/")[1:])
+            endpoint = (
+                self._deployment_text_gen_stream_endpoint
+                if stream
+                else self._deployment_text_gen_endpoint
+            )
+            endpoint = endpoint.format(deployment_id=deployment_id)
+        else:
+            payload["model_id"] = model_id
+            payload["project_id"] = api_params["project_id"]
+            endpoint = (
+                self._text_gen_stream_endpoint if stream else self._text_gen_endpoint
+            )
+        url = api_params["url"].rstrip("/") + endpoint
+        return httpx.Request(
+            "POST", url, headers=headers, json=payload, params=request_params
+        )
+
+    def _get_api_params(self, params: dict, print_verbose: Callable = None) -> dict:
+        """
+        Find watsonx.ai credentials in the params or environment variables and return the headers for authentication.
+        """
+        # Load auth variables from params
+        url = params.pop("url", None)
+        api_key = params.pop("apikey", None)
+        token = params.pop("token", None)
+        project_id = params.pop("project_id", None)  # watsonx.ai project_id
+        space_id = params.pop("space_id", None)  # watsonx.ai deployment space_id
+        region_name = params.pop("region_name", params.pop("region", None))
+        wx_credentials = params.pop("wx_credentials", None)
+        api_version = params.pop("api_version", IBMWatsonXAI.api_version)
+        # Load auth variables from environment variables
+        if url is None:
+            url = (
+                get_secret("WATSONX_URL")
+                or get_secret("WX_URL")
+                or get_secret("WML_URL")
+            )
+        if api_key is None:
+            api_key = get_secret("WATSONX_API_KEY") or get_secret("WX_API_KEY")
+        if token is None:
+            token = get_secret("WATSONX_TOKEN") or get_secret("WX_TOKEN")
+        if project_id is None:
+            project_id = (
+                get_secret("WATSONX_PROJECT_ID")
+                or get_secret("WX_PROJECT_ID")
+                or get_secret("PROJECT_ID")
+            )
+        if region_name is None:
+            region_name = (
+                get_secret("WATSONX_REGION")
+                or get_secret("WX_REGION")
+                or get_secret("REGION")
+            )
+        if space_id is None:
+            space_id = (
+                get_secret("WATSONX_DEPLOYMENT_SPACE_ID")
+                or get_secret("WATSONX_SPACE_ID")
+                or get_secret("WX_SPACE_ID")
+                or get_secret("SPACE_ID")
+            )
+
+        # credentials parsing
+        if wx_credentials is not None:
+            url = wx_credentials.get("url", url)
+            api_key = wx_credentials.get(
+                "apikey", wx_credentials.get("api_key", api_key)
+            )
+            token = wx_credentials.get("token", token)
+
+        # verify that all required credentials are present
+        if url is None:
+            raise WatsonXAIError(
+                status_code=401,
+                message="Error: Watsonx URL not set. Set WX_URL in environment variables or pass in as a parameter.",
+            )
+        if token is None and api_key is not None:
+            # generate the auth token
+            if print_verbose:
+                print_verbose("Generating IAM token for Watsonx.ai")
+            token = self.generate_iam_token(api_key)
+        elif token is None and api_key is None:
+            raise WatsonXAIError(
+                status_code=401,
+                url=url,
+                message="Error: API key or token not found. Set WX_API_KEY or WX_TOKEN in environment variables or pass in as a parameter.",
+            )
+        if project_id is None:
+            raise WatsonXAIError(
+                status_code=401,
+                url=url,
+                message="Error: Watsonx project_id not set. Set WX_PROJECT_ID in environment variables or pass in as a parameter.",
+            )
+
+        return {
+            "url": url,
+            "api_key": api_key,
+            "token": token,
+            "project_id": project_id,
+            "space_id": space_id,
+            "region_name": region_name,
+            "api_version": api_version,
+        }
+
+    def completion(
+        self,
         model: str,
         messages: list,
         custom_prompt_dict: dict,
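
The request builder added above boils down to a plain REST call against the watsonx.ai API. A rough sketch of the equivalent request, assuming a bearer token and project id have already been resolved (all values below are placeholders):

    # sketch: approximate shape of the request _prepare_text_generation_req builds
    import httpx

    request = httpx.Request(
        "POST",
        "https://us-south.ml.cloud.ibm.com/ml/v1/text/generation",  # base url + _text_gen_endpoint
        params={"version": "2024-03-13"},                           # api_version query parameter
        headers={
            "Authorization": "Bearer <iam-token>",                   # from WX_TOKEN or generate_iam_token()
            "Content-Type": "application/json",
            "Accept": "application/json",
        },
        json={
            "input": "<prompt>",
            "model_id": "ibm/granite-13b-chat-v2",                   # placeholder model id
            "project_id": "<project-id>",
            "moderations": {},
            "parameters": {"max_new_tokens": 20},                    # any IBMWatsonXAIConfig params
        },
    )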
@@ -327,154 +332,231 @@ def completion(
         print_verbose: Callable,
         encoding,
         logging_obj,
-        optional_params:Optional[dict]=None,
-        litellm_params:Optional[dict]=None,
+        optional_params: Optional[dict] = None,
+        litellm_params: Optional[dict] = None,
         logger_fn=None,
-        timeout:float=None,
+        timeout: float = None,
     ):
-    from ibm_watsonx_ai.foundation_models import Model, ModelInference
-
-    try:
-        stream = optional_params.pop("stream", False)
-        extra_generate_params = dict(
-            guardrails=optional_params.pop("guardrails", False),
-            guardrails_hap_params=optional_params.pop("guardrails_hap_params", None),
-            guardrails_pii_params=optional_params.pop("guardrails_pii_params", None),
-            concurrency_limit=optional_params.pop("concurrency_limit", 10),
-            async_mode=optional_params.pop("async_mode", False),
-        )
-        if timeout is not None and optional_params.get("time_limit") is None:
-            # the time_limit in watsonx.ai is in milliseconds (as opposed to OpenAI which is in seconds)
-            optional_params['time_limit'] = max(0, int(timeout*1000))
-        extra_body_params = optional_params.pop("extra_body", {})
-        optional_params.update(extra_body_params)
-        # LOAD CONFIG
-        config = IBMWatsonXConfig.get_config()
+        """
+        Send a text generation request to the IBM Watsonx.ai API.
+        Reference: https://cloud.ibm.com/apidocs/watsonx-ai#text-generation
+        """
+        stream = optional_params.pop("stream", False)
+
+        # Load default configs
+        config = IBMWatsonXAIConfig.get_config()
         for k, v in config.items():
             if k not in optional_params:
                 optional_params[k] = v
 
-        model_inference = optional_params.pop("model_inference", None)
-        if model_inference is None:
-            # INIT MODEL
-            model_client:ModelInference = init_watsonx_model(
-                model_id=model,
-                url=optional_params.pop("url", None),
-                api_key=optional_params.pop("api_key", None),
-                project_id=optional_params.pop("project_id", None),
-                space_id=optional_params.pop("space_id", None),
-                wx_credentials=optional_params.pop("wx_credentials", None),
-                region_name=optional_params.pop("region_name", None),
-                verify=optional_params.pop("verify", None),
-                validate=optional_params.pop("validate", False),
-                watsonx_client=optional_params.pop("watsonx_client", None),
-                model_params=optional_params,
-            )
-        else:
-            model_client:ModelInference = model_inference
-            model = model_client.model_id
-
-        # MAKE PROMPT
+        # Make prompt to send to model
         provider = model.split("/")[0]
-        model_name = '/'.join(model.split("/")[1:])
+        # model_name = "/".join(model.split("/")[1:])
         prompt = convert_messages_to_prompt(
             model, messages, provider, custom_prompt_dict
         )
-        ## COMPLETION CALL
-        if stream is True:
-            request_str = (
-                "response = model.generate_text_stream(\n"
-                f"\tprompt={prompt},\n"
-                "\traw_response=True\n)"
-            )
-            logging_obj.pre_call(
-                input=prompt,
-                api_key="",
-                additional_args={
-                    "complete_input_dict": optional_params,
-                    "request_str": request_str,
-                },
-            )
-            # remove params that are not needed for streaming
-            del extra_generate_params["async_mode"]
-            del extra_generate_params["concurrency_limit"]
-            # make generate call
-            response = model_client.generate_text_stream(
-                prompt=prompt,
-                raw_response=True,
-                **extra_generate_params
-            )
-            return litellm.CustomStreamWrapper(
-                response,
-                model=model,
-                custom_llm_provider="watsonx",
-                logging_obj=logging_obj,
-            )
-        else:
-            try:
-                ## LOGGING
-                request_str = (
-                    "response = model.generate(\n"
-                    f"\tprompt={prompt},\n"
-                    "\traw_response=True\n)"
-                )
-                logging_obj.pre_call(
-                    input=prompt,
-                    api_key="",
-                    additional_args={
-                        "complete_input_dict": optional_params,
-                        "request_str": request_str,
-                    },
-                )
-                response = model_client.generate(
-                    prompt=prompt,
-                    **extra_generate_params
-                )
-            except Exception as e:
-                raise WatsonxError(status_code=500, message=str(e))
-
-            ## LOGGING
-            logging_obj.post_call(
-                input=prompt,
-                api_key="",
-                original_response=json.dumps(response),
-                additional_args={"complete_input_dict": optional_params},
-            )
-            print_verbose(f"raw model_response: {response}")
-            ## BUILD RESPONSE OBJECT
-            output_text = response['results'][0]['generated_text']
-
-            try:
-                if (
-                    len(output_text) > 0
-                    and hasattr(model_response.choices[0], "message")
-                ):
-                    model_response["choices"][0]["message"]["content"] = output_text
-                    model_response["finish_reason"] = response['results'][0]['stop_reason']
-                    prompt_tokens = response['results'][0]['input_token_count']
-                    completion_tokens = response['results'][0]['generated_token_count']
-                else:
-                    raise Exception()
-            except:
-                raise WatsonxError(
-                    message=json.dumps(output_text),
-                    status_code=500,
-                )
-            model_response['created'] = int(time.time())
-            model_response['model'] = model_name
-            usage = Usage(
+        def process_text_request(request: httpx.Request) -> ModelResponse:
+            with self._manage_response(
+                request, logging_obj=logging_obj, input=prompt, timeout=timeout
+            ) as resp:
+                json_resp = resp.json()
+
+            generated_text = json_resp["results"][0]["generated_text"]
+            prompt_tokens = json_resp["results"][0]["input_token_count"]
+            completion_tokens = json_resp["results"][0]["generated_token_count"]
+            model_response["choices"][0]["message"]["content"] = generated_text
+            model_response["finish_reason"] = json_resp["results"][0]["stop_reason"]
+            model_response["created"] = int(time.time())
+            model_response["model"] = model
+            model_response.usage = Usage(
                 prompt_tokens=prompt_tokens,
                 completion_tokens=completion_tokens,
                 total_tokens=prompt_tokens + completion_tokens,
             )
-            model_response.usage = usage
             return model_response
-    except WatsonxError as e:
+
+        def process_stream_request(
+            request: httpx.Request,
+        ) -> litellm.CustomStreamWrapper:
+            # stream the response - generated chunks will be handled
+            # by litellm.utils.CustomStreamWrapper.handle_watsonx_stream
+            with self._manage_response(
+                request,
+                logging_obj=logging_obj,
+                stream=True,
+                input=prompt,
+                timeout=timeout,
+            ) as resp:
+                response = litellm.CustomStreamWrapper(
+                    resp.iter_lines(),
+                    model=model,
+                    custom_llm_provider="watsonx",
+                    logging_obj=logging_obj,
+                )
+            return response
+
+        try:
+            ## Get the response from the model
+            request = self._prepare_text_generation_req(
+                model_id=model,
+                prompt=prompt,
+                stream=stream,
+                optional_params=optional_params,
+                print_verbose=print_verbose,
+            )
+            if stream:
+                return process_stream_request(request)
+            else:
+                return process_text_request(request)
+        except WatsonXAIError as e:
             raise e
         except Exception as e:
-            raise WatsonxError(status_code=500, message=str(e))
+            raise WatsonXAIError(status_code=500, message=str(e))
 
+    def embedding(
+        self,
+        model: str,
+        input: Union[list, str],
+        api_key: Optional[str] = None,
+        logging_obj=None,
+        model_response=None,
+        optional_params=None,
+        encoding=None,
+    ):
+        """
+        Send a text embedding request to the IBM Watsonx.ai API.
+        """
+        if optional_params is None:
+            optional_params = {}
+        # Load default configs
+        config = IBMWatsonXAIConfig.get_config()
+        for k, v in config.items():
+            if k not in optional_params:
+                optional_params[k] = v
 
-def embedding():
-    # logic for parsing in - calling - parsing out model embedding calls
-    pass
+        # Load auth variables from environment variables
+        if isinstance(input, str):
+            input = [input]
+        if api_key is not None:
+            optional_params["api_key"] = api_key
+        api_params = self._get_api_params(optional_params)
+        # build auth headers
+        api_token = api_params.get("token")
+        headers = {
+            "Authorization": f"Bearer {api_token}",
+            "Content-Type": "application/json",
+            "Accept": "application/json",
+        }
+        # init the payload to the text generation call
+        payload = {
+            "inputs": input,
+            "model_id": model,
+            "project_id": api_params["project_id"],
+            "parameters": optional_params,
+        }
+        request_params = dict(version=api_params["api_version"])
+        url = api_params["url"].rstrip("/") + self._embeddings_endpoint
+        request = httpx.Request(
+            "POST", url, headers=headers, json=payload, params=request_params
+        )
+        with self._manage_response(
+            request, logging_obj=logging_obj, input=input
+        ) as resp:
+            json_resp = resp.json()
+
+        results = json_resp.get("results", [])
+        embedding_response = []
+        for idx, result in enumerate(results):
+            embedding_response.append(
+                {"object": "embedding", "index": idx, "embedding": result["embedding"]}
+            )
+        model_response["object"] = "list"
+        model_response["data"] = embedding_response
+        model_response["model"] = model
+        input_tokens = json_resp.get("input_token_count", 0)
+        model_response.usage = Usage(
+            prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens
+        )
+        return model_response
+
+    def generate_iam_token(self, api_key=None, **params):
+        headers = {}
+        headers["Content-Type"] = "application/x-www-form-urlencoded"
+        if api_key is None:
+            api_key = get_secret("WX_API_KEY") or get_secret("WATSONX_API_KEY")
+        if api_key is None:
+            raise ValueError("API key is required")
+        headers["Accept"] = "application/json"
+        data = {
+            "grant_type": "urn:ibm:params:oauth:grant-type:apikey",
+            "apikey": api_key,
+        }
+        response = httpx.post(
+            "https://iam.cloud.ibm.com/identity/token", data=data, headers=headers
+        )
+        response.raise_for_status()
+        json_data = response.json()
+        iam_access_token = json_data["access_token"]
+        self.token = iam_access_token
+        return iam_access_token
+
+    @contextmanager
+    def _manage_response(
+        self,
+        request: httpx.Request,
+        logging_obj: Any,
+        stream: bool = False,
+        input: Optional[Any] = None,
+        timeout: float = None,
+    ):
+        request_str = (
+            f"response = {request.method}(\n"
+            f"\turl={request.url},\n"
+            f"\tjson={request.content.decode()},\n"
+            f")"
+        )
+        json_input = json.loads(request.content.decode())
+        headers = dict(request.headers)
+        logging_obj.pre_call(
+            input=input,
+            api_key=request.headers.get("Authorization"),
+            additional_args={
+                "complete_input_dict": json_input,
+                "request_str": request_str,
+            },
+        )
+        try:
+            if stream:
+                resp = requests.request(
+                    method=request.method,
+                    url=str(request.url),
+                    headers=headers,
+                    json=json_input,
+                    stream=True,
+                    timeout=timeout,
+                )
+                # resp.raise_for_status()
+                yield resp
+            else:
+                resp = requests.request(
+                    method=request.method,
+                    url=str(request.url),
+                    headers=headers,
+                    json=json_input,
+                    timeout=timeout,
+                )
+                resp.raise_for_status()
+                yield resp
+        except Exception as e:
+            raise WatsonXAIError(status_code=500, message=str(e))
+        if not stream:
+            logging_obj.post_call(
+                input=input,
+                api_key=request.headers.get("Authorization"),
+                original_response=json.dumps(resp.json()),
+                additional_args={
+                    "status_code": resp.status_code,
+                    "complete_input_dict": request,
+                },
+            )
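
generate_iam_token, added above, exchanges an IBM Cloud API key for a bearer token using the IAM identity endpoint. A standalone sketch of the same exchange (the API key is a placeholder):

    # sketch: raw IAM token exchange, mirroring generate_iam_token above
    import httpx

    resp = httpx.post(
        "https://iam.cloud.ibm.com/identity/token",
        headers={
            "Content-Type": "application/x-www-form-urlencoded",
            "Accept": "application/json",
        },
        data={
            "grant_type": "urn:ibm:params:oauth:grant-type:apikey",
            "apikey": "<ibm-cloud-api-key>",  # placeholder
        },
    )
    resp.raise_for_status()
    iam_token = resp.json()["access_token"]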
@@ -1862,7 +1862,7 @@ def completion(
             response = response
         elif custom_llm_provider == "watsonx":
             custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict
-            response = watsonx.completion(
+            response = watsonx.IBMWatsonXAI().completion(
                 model=model,
                 messages=messages,
                 custom_prompt_dict=custom_prompt_dict,
@@ -2976,6 +2976,15 @@ def embedding(
             client=client,
             aembedding=aembedding,
         )
+    elif custom_llm_provider == "watsonx":
+        response = watsonx.IBMWatsonXAI().embedding(
+            model=model,
+            input=input,
+            encoding=encoding,
+            logging_obj=logging,
+            optional_params=optional_params,
+            model_response=EmbeddingResponse(),
+        )
     else:
         args = locals()
         raise ValueError(f"No valid embedding model args passed in - {args}")
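
With this branch in place, watsonx embeddings go through the normal litellm entry point. A hedged usage sketch (the embedding model id is an assumption, not taken from this diff):

    # sketch: requesting embeddings via the new watsonx branch
    import litellm

    embedding_response = litellm.embedding(
        model="watsonx/ibm/slate-30m-english-rtrvr",  # assumed embedding model id
        input=["Hello world"],
    )
    print(len(embedding_response.data[0]["embedding"]))  # vector length of the first input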
@@ -5771,7 +5771,7 @@ def get_supported_openai_params(model: str, custom_llm_provider: str):
             "presence_penalty",
         ]
     elif custom_llm_provider == "watsonx":
-        return litellm.IBMWatsonXConfig().get_supported_openai_params()
+        return litellm.IBMWatsonXAIConfig().get_supported_openai_params()
 
 
 def get_formatted_prompt(
@@ -9682,20 +9682,31 @@ class CustomStreamWrapper:
     def handle_watsonx_stream(self, chunk):
         try:
             if isinstance(chunk, dict):
-                pass
-            elif isinstance(chunk, str):
-                chunk = json.loads(chunk)
-            result = chunk.get("results", [])
-            if len(result) > 0:
-                text = result[0].get("generated_text", "")
-                finish_reason = result[0].get("stop_reason")
+                parsed_response = chunk
+            elif isinstance(chunk, (str, bytes)):
+                if isinstance(chunk, bytes):
+                    chunk = chunk.decode("utf-8")
+                if 'generated_text' in chunk:
+                    response = chunk.replace('data: ', '').strip()
+                    parsed_response = json.loads(response)
+                else:
+                    return {"text": "", "is_finished": False}
+            else:
+                print_verbose(f"chunk: {chunk} (Type: {type(chunk)})")
+                raise ValueError(f"Unable to parse response. Original response: {chunk}")
+            results = parsed_response.get("results", [])
+            if len(results) > 0:
+                text = results[0].get("generated_text", "")
+                finish_reason = results[0].get("stop_reason")
                 is_finished = finish_reason != 'not_finished'
                 return {
                     "text": text,
                     "is_finished": is_finished,
                     "finish_reason": finish_reason,
+                    "prompt_tokens": results[0].get("input_token_count", None),
+                    "completion_tokens": results[0].get("generated_token_count", None),
                 }
-            return ""
+            return {"text": "", "is_finished": False}
         except Exception as e:
             raise e
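
For reference, the generation_stream endpoint delivers server-sent events, so each chunk handled above is a "data: ..." line. An illustrative example of what the parser sees and returns (the payload values are made up for illustration; the field names come from the code above):

    # sketch: a raw streamed line and the dict handle_watsonx_stream would produce from it
    chunk = (
        'data: {"results": [{"generated_text": "Hello", '
        '"generated_token_count": 1, "input_token_count": 5, '
        '"stop_reason": "not_finished"}]}'
    )
    # After stripping the "data: " prefix and json.loads(), the handler returns roughly:
    # {"text": "Hello", "is_finished": False, "finish_reason": "not_finished",
    #  "prompt_tokens": 5, "completion_tokens": 1}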
@@ -9957,6 +9968,15 @@ class CustomStreamWrapper:
                 response_obj = self.handle_watsonx_stream(chunk)
                 completion_obj["content"] = response_obj["text"]
                 print_verbose(f"completion obj content: {completion_obj['content']}")
+                if response_obj.get("prompt_tokens") is not None:
+                    prompt_token_count = getattr(model_response.usage, "prompt_tokens", 0)
+                    model_response.usage.prompt_tokens = (prompt_token_count+response_obj["prompt_tokens"])
+                if response_obj.get("completion_tokens") is not None:
+                    model_response.usage.completion_tokens = response_obj["completion_tokens"]
+                model_response.usage.total_tokens = (
+                    getattr(model_response.usage, "prompt_tokens", 0)
+                    + getattr(model_response.usage, "completion_tokens", 0)
+                )
                 if response_obj["is_finished"]:
                     self.received_finish_reason = response_obj["finish_reason"]
             elif self.custom_llm_provider == "text-completion-openai":