forked from phoenix/litellm-mirror
feat - watsonx refractoring, removed dependency, and added support for embedding calls
This commit is contained in:
parent
a77537ddd4
commit
74d2ba0a23
4 changed files with 477 additions and 366 deletions
|
@ -5771,7 +5771,7 @@ def get_supported_openai_params(model: str, custom_llm_provider: str):
|
|||
"presence_penalty",
|
||||
]
|
||||
elif custom_llm_provider == "watsonx":
|
||||
return litellm.IBMWatsonXConfig().get_supported_openai_params()
|
||||
return litellm.IBMWatsonXAIConfig().get_supported_openai_params()
|
||||
|
||||
|
||||
def get_formatted_prompt(
|
||||
|
@ -9682,20 +9682,31 @@ class CustomStreamWrapper:
|
|||
def handle_watsonx_stream(self, chunk):
|
||||
try:
|
||||
if isinstance(chunk, dict):
|
||||
pass
|
||||
elif isinstance(chunk, str):
|
||||
chunk = json.loads(chunk)
|
||||
result = chunk.get("results", [])
|
||||
if len(result) > 0:
|
||||
text = result[0].get("generated_text", "")
|
||||
finish_reason = result[0].get("stop_reason")
|
||||
parsed_response = chunk
|
||||
elif isinstance(chunk, (str, bytes)):
|
||||
if isinstance(chunk, bytes):
|
||||
chunk = chunk.decode("utf-8")
|
||||
if 'generated_text' in chunk:
|
||||
response = chunk.replace('data: ', '').strip()
|
||||
parsed_response = json.loads(response)
|
||||
else:
|
||||
return {"text": "", "is_finished": False}
|
||||
else:
|
||||
print_verbose(f"chunk: {chunk} (Type: {type(chunk)})")
|
||||
raise ValueError(f"Unable to parse response. Original response: {chunk}")
|
||||
results = parsed_response.get("results", [])
|
||||
if len(results) > 0:
|
||||
text = results[0].get("generated_text", "")
|
||||
finish_reason = results[0].get("stop_reason")
|
||||
is_finished = finish_reason != 'not_finished'
|
||||
return {
|
||||
"text": text,
|
||||
"is_finished": is_finished,
|
||||
"finish_reason": finish_reason,
|
||||
"prompt_tokens": results[0].get("input_token_count", None),
|
||||
"completion_tokens": results[0].get("generated_token_count", None),
|
||||
}
|
||||
return ""
|
||||
return {"text": "", "is_finished": False}
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
|
@ -9957,6 +9968,15 @@ class CustomStreamWrapper:
|
|||
response_obj = self.handle_watsonx_stream(chunk)
|
||||
completion_obj["content"] = response_obj["text"]
|
||||
print_verbose(f"completion obj content: {completion_obj['content']}")
|
||||
if response_obj.get("prompt_tokens") is not None:
|
||||
prompt_token_count = getattr(model_response.usage, "prompt_tokens", 0)
|
||||
model_response.usage.prompt_tokens = (prompt_token_count+response_obj["prompt_tokens"])
|
||||
if response_obj.get("completion_tokens") is not None:
|
||||
model_response.usage.completion_tokens = response_obj["completion_tokens"]
|
||||
model_response.usage.total_tokens = (
|
||||
getattr(model_response.usage, "prompt_tokens", 0)
|
||||
+ getattr(model_response.usage, "completion_tokens", 0)
|
||||
)
|
||||
if response_obj["is_finished"]:
|
||||
self.received_finish_reason = response_obj["finish_reason"]
|
||||
elif self.custom_llm_provider == "text-completion-openai":
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue