feat - watsonx refactoring, removed dependency, and added support for embedding calls

Simon Sanchez Viloria 2024-04-23 11:53:38 +02:00
parent a77537ddd4
commit 74d2ba0a23
4 changed files with 477 additions and 366 deletions
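
The diff below only shows the param-mapping and streaming changes; as a quick illustration of the embedding support named in the commit title, a call through litellm's generic embedding entry point might look like this (the model id is a placeholder, not taken from this diff):

import litellm

# Hypothetical usage of the watsonx embedding support this commit's title
# mentions; the model id is illustrative, not taken from the diff.
response = litellm.embedding(
    model="watsonx/ibm/slate-30m-english-rtrvr",
    input=["Hello from watsonx"],
)
print(response.data[0]["embedding"][:5])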


@@ -5771,7 +5771,7 @@ def get_supported_openai_params(model: str, custom_llm_provider: str):
             "presence_penalty",
         ]
     elif custom_llm_provider == "watsonx":
-        return litellm.IBMWatsonXConfig().get_supported_openai_params()
+        return litellm.IBMWatsonXAIConfig().get_supported_openai_params()
 
 
 def get_formatted_prompt(
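
As a sketch of the renamed lookup in use (the model id is a placeholder; the exact list returned depends on IBMWatsonXAIConfig):

from litellm.utils import get_supported_openai_params

# Placeholder model id; returns the OpenAI-style params mapped for watsonx.
params = get_supported_openai_params(
    model="watsonx/ibm/granite-13b-chat-v2",
    custom_llm_provider="watsonx",
)
print(params)  # list of supported OpenAI params per IBMWatsonXAIConfig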
@@ -9682,20 +9682,31 @@ class CustomStreamWrapper:
     def handle_watsonx_stream(self, chunk):
         try:
             if isinstance(chunk, dict):
-                pass
-            elif isinstance(chunk, str):
-                chunk = json.loads(chunk)
-            result = chunk.get("results", [])
-            if len(result) > 0:
-                text = result[0].get("generated_text", "")
-                finish_reason = result[0].get("stop_reason")
+                parsed_response = chunk
+            elif isinstance(chunk, (str, bytes)):
+                if isinstance(chunk, bytes):
+                    chunk = chunk.decode("utf-8")
+                if 'generated_text' in chunk:
+                    response = chunk.replace('data: ', '').strip()
+                    parsed_response = json.loads(response)
+                else:
+                    return {"text": "", "is_finished": False}
+            else:
+                print_verbose(f"chunk: {chunk} (Type: {type(chunk)})")
+                raise ValueError(f"Unable to parse response. Original response: {chunk}")
+            results = parsed_response.get("results", [])
+            if len(results) > 0:
+                text = results[0].get("generated_text", "")
+                finish_reason = results[0].get("stop_reason")
                 is_finished = finish_reason != 'not_finished'
                 return {
                     "text": text,
                     "is_finished": is_finished,
                     "finish_reason": finish_reason,
+                    "prompt_tokens": results[0].get("input_token_count", None),
+                    "completion_tokens": results[0].get("generated_token_count", None),
                 }
-            return ""
+            return {"text": "", "is_finished": False}
         except Exception as e:
             raise e
 
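
To see what the new handler does with a raw server-sent-event line, here is a minimal standalone sketch; the exact 'data: {...}' framing of watsonx stream chunks is an assumption about the wire format the code above strips:

import json

# Hypothetical raw SSE chunk in the shape the handler above expects.
raw_chunk = (
    b'data: {"results": [{"generated_text": " world", '
    b'"stop_reason": "eos_token", '
    b'"input_token_count": 3, "generated_token_count": 5}]}'
)

chunk = raw_chunk.decode("utf-8")        # bytes -> str
assert 'generated_text' in chunk        # otherwise treated as an empty chunk
parsed = json.loads(chunk.replace('data: ', '').strip())
result = parsed["results"][0]
print(result["generated_text"])                  # " world"
print(result["stop_reason"] != 'not_finished')   # True -> stream finished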
@@ -9957,6 +9968,15 @@ class CustomStreamWrapper:
                 response_obj = self.handle_watsonx_stream(chunk)
                 completion_obj["content"] = response_obj["text"]
                 print_verbose(f"completion obj content: {completion_obj['content']}")
+                if response_obj.get("prompt_tokens") is not None:
+                    prompt_token_count = getattr(model_response.usage, "prompt_tokens", 0)
+                    model_response.usage.prompt_tokens = (prompt_token_count + response_obj["prompt_tokens"])
+                if response_obj.get("completion_tokens") is not None:
+                    model_response.usage.completion_tokens = response_obj["completion_tokens"]
+                model_response.usage.total_tokens = (
+                    getattr(model_response.usage, "prompt_tokens", 0)
+                    + getattr(model_response.usage, "completion_tokens", 0)
+                )
                 if response_obj["is_finished"]:
                     self.received_finish_reason = response_obj["finish_reason"]
             elif self.custom_llm_provider == "text-completion-openai":
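
A toy run of the usage bookkeeping added above: prompt tokens are accumulated across chunks, while completion tokens are overwritten with the latest count reported (the code above treats it as cumulative); the SimpleNamespace stands in for model_response.usage:

from types import SimpleNamespace

usage = SimpleNamespace(prompt_tokens=0, completion_tokens=0, total_tokens=0)
chunks = [
    {"prompt_tokens": 3, "completion_tokens": 1},    # first chunk reports the prompt size
    {"prompt_tokens": None, "completion_tokens": 2}, # later chunks only update the completion count
]
for response_obj in chunks:
    if response_obj.get("prompt_tokens") is not None:
        usage.prompt_tokens += response_obj["prompt_tokens"]
    if response_obj.get("completion_tokens") is not None:
        usage.completion_tokens = response_obj["completion_tokens"]
    usage.total_tokens = usage.prompt_tokens + usage.completion_tokens

print(usage)  # namespace(prompt_tokens=3, completion_tokens=2, total_tokens=5)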