Merge pull request #3369 from mogith-pn/main

Clarifai-LiteLLM : Added clarifai as LLM Provider.
This commit is contained in:
Krish Dholakia 2024-05-11 09:31:46 -07:00 committed by GitHub
commit 8f6ae9a059
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 944 additions and 2 deletions

View file

@ -2946,6 +2946,7 @@ def client(original_function):
)
else:
return result
return result
# Prints Exactly what was passed to litellm function - don't execute any logic here - it should just print
@ -3049,6 +3050,7 @@ def client(original_function):
model_response_object=ModelResponse(),
stream=kwargs.get("stream", False),
)
if kwargs.get("stream", False) == True:
cached_result = CustomStreamWrapper(
completion_stream=cached_result,
@ -10449,6 +10451,27 @@ class CustomStreamWrapper:
return {"text": "", "is_finished": False}
except Exception as e:
raise e
def handle_clarifai_completion_chunk(self, chunk):
try:
if isinstance(chunk, dict):
parsed_response = chunk
if isinstance(chunk, (str, bytes)):
if isinstance(chunk, bytes):
parsed_response = chunk.decode("utf-8")
else:
parsed_response = chunk
data_json = json.loads(parsed_response)
text = data_json.get("outputs", "")[0].get("data", "").get("text", "").get("raw","")
prompt_tokens = len(encoding.encode(data_json.get("outputs", "")[0].get("input","").get("data", "").get("text", "").get("raw","")))
completion_tokens = len(encoding.encode(text))
return {
"text": text,
"is_finished": True,
}
except:
traceback.print_exc()
return ""
def model_response_creator(self):
model_response = ModelResponse(
@ -10495,6 +10518,11 @@ class CustomStreamWrapper:
completion_obj["content"] = response_obj["text"]
if response_obj["is_finished"]:
self.received_finish_reason = response_obj["finish_reason"]
elif (
self.custom_llm_provider and self.custom_llm_provider == "clarifai"
):
response_obj = self.handle_clarifai_completion_chunk(chunk)
completion_obj["content"] = response_obj["text"]
elif self.model == "replicate" or self.custom_llm_provider == "replicate":
response_obj = self.handle_replicate_chunk(chunk)
completion_obj["content"] = response_obj["text"]