Merge pull request #4158 from BerriAI/litellm_fix_clarifai

[Fix] Add ClarifAI support for LiteLLM Proxy
Ishaan Jaff 2024-06-12 17:17:16 -07:00 committed by GitHub
commit 06ac381d57
3 changed files with 56 additions and 14 deletions

@@ -139,6 +139,7 @@ def process_response(
 def convert_model_to_url(model: str, api_base: str):
     user_id, app_id, model_id = model.split(".")
+    model_id = model_id.lower()
     return f"{api_base}/users/{user_id}/apps/{app_id}/models/{model_id}/outputs"
@@ -171,19 +172,55 @@ async def async_completion(
     async_handler = AsyncHTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0))
     response = await async_handler.post(
-        api_base, headers=headers, data=json.dumps(data)
+        url=model, headers=headers, data=json.dumps(data)
     )
-    return process_response(
-        model=model,
-        prompt=prompt,
-        response=response,
-        model_response=model_response,
+    logging_obj.post_call(
+        input=prompt,
         api_key=api_key,
-        data=data,
-        encoding=encoding,
-        logging_obj=logging_obj,
+        original_response=response.text,
+        additional_args={"complete_input_dict": data},
     )
+    ## RESPONSE OBJECT
+    try:
+        completion_response = response.json()
+    except Exception:
+        raise ClarifaiError(
+            message=response.text, status_code=response.status_code, url=model
+        )
+    # print(completion_response)
+    try:
+        choices_list = []
+        for idx, item in enumerate(completion_response["outputs"]):
+            if len(item["data"]["text"]["raw"]) > 0:
+                message_obj = Message(content=item["data"]["text"]["raw"])
+            else:
+                message_obj = Message(content=None)
+            choice_obj = Choices(
+                finish_reason="stop",
+                index=idx + 1,  # check
+                message=message_obj,
+            )
+            choices_list.append(choice_obj)
+        model_response["choices"] = choices_list
+    except Exception as e:
+        raise ClarifaiError(
+            message=traceback.format_exc(), status_code=response.status_code, url=model
+        )
+    # Calculate Usage
+    prompt_tokens = len(encoding.encode(prompt))
+    completion_tokens = len(
+        encoding.encode(model_response["choices"][0]["message"].get("content"))
+    )
+    model_response["model"] = model
+    model_response["usage"] = Usage(
+        prompt_tokens=prompt_tokens,
+        completion_tokens=completion_tokens,
+        total_tokens=prompt_tokens + completion_tokens,
+    )
+    return model_response
 def completion(
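For readers unfamiliar with Clarifai's response shape, here is a hedged sketch of the payload the new parsing loop expects and the OpenAI-style choices it produces; the payload values are invented for illustration, and plain dicts stand in for litellm's Message/Choices types. Usage is then estimated by re-tokenizing the prompt and completion with the local encoding, rather than reading counts from the provider response.

```python
# Invented example payload in the shape the loop above expects:
# each output carries generated text under data.text.raw.
completion_response = {
    "outputs": [
        {"data": {"text": {"raw": "Hello from Clarifai!"}}},
    ]
}

choices = []
for idx, item in enumerate(completion_response["outputs"]):
    raw = item["data"]["text"]["raw"]
    choices.append(
        {
            "finish_reason": "stop",
            "index": idx + 1,  # mirrors the diff; OpenAI-style indices usually start at 0
            "message": {"content": raw if len(raw) > 0 else None},
        }
    )

print(choices[0]["message"]["content"])  # Hello from Clarifai!
```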
@@ -241,7 +278,7 @@ def completion(
         additional_args={
             "complete_input_dict": data,
             "headers": headers,
-            "api_base": api_base,
+            "api_base": model,
         },
     )
     if acompletion == True:
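End to end, the change means the Clarifai model string is converted into a full outputs URL (held in `model`) before the request is sent and before the call is logged. A hedged usage sketch of calling a Clarifai-hosted model through litellm once this fix is in place; the model path and the API key value are placeholders, not confirmed by the diff:

```python
import os
import litellm

os.environ["CLARIFAI_API_KEY"] = "<your-clarifai-pat>"  # placeholder token

# "clarifai/<user_id>.<app_id>.<model_id>" follows litellm's provider/model
# naming convention; the specific model path here is illustrative.
response = litellm.completion(
    model="clarifai/openai.chat-completion.GPT-4",
    messages=[{"role": "user", "content": "Say hello"}],
)
print(response.choices[0].message.content)
```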