Merge pull request #4158 from BerriAI/litellm_fix_clarifai

[Fix] Add ClarifAI support for LiteLLM Proxy
This commit is contained in:
Ishaan Jaff 2024-06-12 17:17:16 -07:00 committed by GitHub
commit 06ac381d57
3 changed files with 56 additions and 14 deletions

View file

@@ -1,10 +1,13 @@
# Clarifai
Anthropic, OpenAI, Mistral, Llama, and Gemini LLMs are supported on Clarifai.
:::warning
Streaming is not yet supported when using Clarifai with LiteLLM. Tracking support here: https://github.com/BerriAI/litellm/issues/4162
:::
## Pre-Requisites
`pip install clarifai`
`pip install litellm`
## Required Environment Variables
@@ -12,6 +15,7 @@ To obtain your Clarifai Personal access token follow this [link](https://docs.cl
```python
os.environ["CLARIFAI_API_KEY"] = "YOUR_CLARIFAI_PAT" # CLARIFAI_PAT
```
## Usage
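A minimal, non-streaming sketch (the model name is taken from the tables below; any model deployed on Clarifai should work the same way):

```python
import os
from litellm import completion

os.environ["CLARIFAI_API_KEY"] = "YOUR_CLARIFAI_PAT"

response = completion(
    model="clarifai/mistralai.completion.mixtral-8x22B",
    messages=[{"role": "user", "content": "Write a short poem about the ocean."}],
)
print(response.choices[0].message.content)
```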
@@ -68,7 +72,7 @@ Example Usage - Note: liteLLM supports all models deployed on Clarifai
| clarifai/meta.Llama-2.codeLlama-70b-Python | `completion('clarifai/meta.Llama-2.codeLlama-70b-Python', messages)`|
| clarifai/meta.Llama-2.codeLlama-70b-Instruct | `completion('clarifai/meta.Llama-2.codeLlama-70b-Instruct', messages)` |
## Mistal LLMs
## Mistral LLMs
| Model Name | Function Call |
|---------------------------------------------|------------------------------------------------------------------------|
| clarifai/mistralai.completion.mixtral-8x22B | `completion('clarifai/mistralai.completion.mixtral-8x22B', messages)` |
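
Because streaming is not yet supported (see the warning above), async calls go through `litellm.acompletion` without `stream=True`. A minimal sketch, assuming `CLARIFAI_API_KEY` is set as shown earlier:

```python
import asyncio
from litellm import acompletion

async def main():
    response = await acompletion(
        model="clarifai/mistralai.completion.mixtral-8x22B",
        messages=[{"role": "user", "content": "What is Clarifai?"}],
    )
    print(response.choices[0].message.content)

asyncio.run(main())
```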

View file

@@ -139,6 +139,7 @@ def process_response(
def convert_model_to_url(model: str, api_base: str):
user_id, app_id, model_id = model.split(".")
model_id = model_id.lower()
return f"{api_base}/users/{user_id}/apps/{app_id}/models/{model_id}/outputs"
@@ -171,19 +172,55 @@ async def async_completion(
async_handler = AsyncHTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0))
response = await async_handler.post(
api_base, headers=headers, data=json.dumps(data)
url=model, headers=headers, data=json.dumps(data)
)
return process_response(
model=model,
prompt=prompt,
response=response,
model_response=model_response,
api_key=api_key,
data=data,
encoding=encoding,
logging_obj=logging_obj,
)
logging_obj.post_call(
input=prompt,
api_key=api_key,
original_response=response.text,
additional_args={"complete_input_dict": data},
)
## RESPONSE OBJECT
try:
completion_response = response.json()
except Exception:
raise ClarifaiError(
message=response.text, status_code=response.status_code, url=model
)
# print(completion_response)
try:
choices_list = []
for idx, item in enumerate(completion_response["outputs"]):
if len(item["data"]["text"]["raw"]) > 0:
message_obj = Message(content=item["data"]["text"]["raw"])
else:
message_obj = Message(content=None)
choice_obj = Choices(
finish_reason="stop",
index=idx + 1, # check
message=message_obj,
)
choices_list.append(choice_obj)
model_response["choices"] = choices_list
except Exception as e:
raise ClarifaiError(
message=traceback.format_exc(), status_code=response.status_code, url=model
)
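# Expected (assumed) shape of completion_response for the parsing above --
# field names are taken from the code, the text value is invented:
# {"outputs": [{"data": {"text": {"raw": "Paris is the capital of France."}}}]}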
# Calculate Usage
prompt_tokens = len(encoding.encode(prompt))
completion_tokens = len(
encoding.encode(model_response["choices"][0]["message"].get("content"))
)
model_response["model"] = model
model_response["usage"] = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
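# Note: completion_tokens above counts only the first choice's content, and
# total_tokens is simply their sum -- e.g. a 7-token prompt with a 9-token
# completion yields Usage(prompt_tokens=7, completion_tokens=9, total_tokens=16).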
return model_response
def completion(
@@ -241,7 +278,7 @@ def completion(
additional_args={
"complete_input_dict": data,
"headers": headers,
"api_base": api_base,
"api_base": model,
},
)
if acompletion == True:

View file

@@ -335,6 +335,7 @@ async def acompletion(
or custom_llm_provider == "predibase"
or custom_llm_provider == "bedrock"
or custom_llm_provider == "databricks"
or custom_llm_provider == "clarifai"
or custom_llm_provider in litellm.openai_compatible_providers
): # currently implemented aiohttp calls for just azure, openai, hf, ollama, vertex ai soon all.
init_response = await loop.run_in_executor(None, func_with_context)
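# Sketch (assumption, not part of this diff): with clarifai now included in the
# async provider list above, a LiteLLM Proxy that serves a Clarifai model can be
# called with any OpenAI-compatible client. The proxy URL, key, and model alias
# below are hypothetical placeholders.
import openai

client = openai.OpenAI(base_url="http://0.0.0.0:4000", api_key="sk-1234")
response = client.chat.completions.create(
    model="clarifai/mistralai.completion.mixtral-8x22B",
    messages=[{"role": "user", "content": "Hello from the proxy"}],
)
print(response.choices[0].message.content)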