diff --git a/docs/my-website/docs/completion/input.md b/docs/my-website/docs/completion/input.md
index 1c2a599ca..e844c541c 100644
--- a/docs/my-website/docs/completion/input.md
+++ b/docs/my-website/docs/completion/input.md
@@ -60,6 +60,7 @@ Use `litellm.get_supported_openai_params()` for an updated list of params for ea
 |Petals| ✅ | ✅ | | ✅ | | | | | | |
 |Ollama| ✅ | ✅ | ✅ | ✅ | ✅ | | | ✅ | | | | | ✅ | | |
 |Databricks| ✅ | ✅ | ✅ | ✅ | ✅ | | | | | | | | | | |
+|Clarifai| ✅ | ✅ | | | | | | | | | | | | | |
 
 :::note
diff --git a/docs/my-website/docs/providers/clarifai.md b/docs/my-website/docs/providers/clarifai.md
index b1172b701..6a0bd2211 100644
--- a/docs/my-website/docs/providers/clarifai.md
+++ b/docs/my-website/docs/providers/clarifai.md
@@ -55,7 +55,7 @@ response = completion(
 ```
 
 ## Clarifai models
-liteLLM supports non-streaming requests to all models on [Clarifai community](https://clarifai.com/explore/models?filterData=%5B%7B%22field%22%3A%22use_cases%22%2C%22value%22%3A%5B%22llm%22%5D%7D%5D&page=1&perPage=24)
+liteLLM supports all models on [Clarifai community](https://clarifai.com/explore/models?filterData=%5B%7B%22field%22%3A%22use_cases%22%2C%22value%22%3A%5B%22llm%22%5D%7D%5D&page=1&perPage=24)
 
 Example Usage - Note: liteLLM supports all models deployed on Clarifai
 
diff --git a/litellm/llms/clarifai.py b/litellm/llms/clarifai.py
index e07a8d9e8..4610911e1 100644
--- a/litellm/llms/clarifai.py
+++ b/litellm/llms/clarifai.py
@@ -14,28 +14,25 @@ class ClarifaiError(Exception):
     def __init__(self, status_code, message, url):
         self.status_code = status_code
         self.message = message
-        self.request = httpx.Request(
-            method="POST", url=url
-        )
+        self.request = httpx.Request(method="POST", url=url)
         self.response = httpx.Response(status_code=status_code, request=self.request)
-        super().__init__(
-            self.message
-        )
+        super().__init__(self.message)
+
 
 class ClarifaiConfig:
     """
     Reference: https://clarifai.com/meta/Llama-2/models/llama2-70b-chat
-    TODO fill in the details
     """
+
     max_tokens: Optional[int] = None
     temperature: Optional[int] = None
     top_k: Optional[int] = None
 
     def __init__(
-            self,
-            max_tokens: Optional[int] = None,
-            temperature: Optional[int] = None,
-            top_k: Optional[int] = None,
+        self,
+        max_tokens: Optional[int] = None,
+        temperature: Optional[int] = None,
+        top_k: Optional[int] = None,
     ) -> None:
         locals_ = locals()
         for key, value in locals_.items():
@@ -60,6 +57,7 @@ class ClarifaiConfig:
             and v is not None
         }
 
+
 def validate_environment(api_key):
     headers = {
         "accept": "application/json",
@@ -69,42 +67,37 @@ def validate_environment(api_key):
         headers["Authorization"] = f"Bearer {api_key}"
     return headers
 
-def completions_to_model(payload):
-    # if payload["n"] != 1:
-    #     raise HTTPException(
-    #         status_code=422,
-    #         detail="Only one generation is supported. Please set candidate_count to 1.",
-    #     )
-    params = {}
-    if temperature := payload.get("temperature"):
-        params["temperature"] = temperature
-    if max_tokens := payload.get("max_tokens"):
-        params["max_tokens"] = max_tokens
-    return {
-        "inputs": [{"data": {"text": {"raw": payload["prompt"]}}}],
-        "model": {"output_info": {"params": params}},
-}
-
+def completions_to_model(payload):
+    # if payload["n"] != 1:
+    #     raise HTTPException(
+    #         status_code=422,
+    #         detail="Only one generation is supported. Please set candidate_count to 1.",
+    #     )
+
+    params = {}
+    if temperature := payload.get("temperature"):
+        params["temperature"] = temperature
+    if max_tokens := payload.get("max_tokens"):
+        params["max_tokens"] = max_tokens
+    return {
+        "inputs": [{"data": {"text": {"raw": payload["prompt"]}}}],
+        "model": {"output_info": {"params": params}},
+    }
+
+
 def process_response(
-        model,
-        prompt,
-        response,
-        model_response,
-        api_key,
-        data,
-        encoding,
-        logging_obj
-        ):
+    model, prompt, response, model_response, api_key, data, encoding, logging_obj
+):
     logging_obj.post_call(
-            input=prompt,
-            api_key=api_key,
-            original_response=response.text,
-            additional_args={"complete_input_dict": data},
-        )
-        ## RESPONSE OBJECT
+        input=prompt,
+        api_key=api_key,
+        original_response=response.text,
+        additional_args={"complete_input_dict": data},
+    )
+    ## RESPONSE OBJECT
     try:
-            completion_response = response.json()
+        completion_response = response.json()
     except Exception:
         raise ClarifaiError(
             message=response.text, status_code=response.status_code, url=model
         )
@@ -119,7 +112,7 @@ def process_response(
             message_obj = Message(content=None)
         choice_obj = Choices(
             finish_reason="stop",
-            index=idx + 1, #check
+            index=idx + 1,  # check
             message=message_obj,
         )
         choices_list.append(choice_obj)
@@ -143,53 +136,56 @@ def process_response(
     )
     return model_response
 
+
 def convert_model_to_url(model: str, api_base: str):
     user_id, app_id, model_id = model.split(".")
     return f"{api_base}/users/{user_id}/apps/{app_id}/models/{model_id}/outputs"
 
+
 def get_prompt_model_name(url: str):
     clarifai_model_name = url.split("/")[-2]
     if "claude" in clarifai_model_name:
         return "anthropic", clarifai_model_name.replace("_", ".")
-    if ("llama" in clarifai_model_name)or ("mistral" in clarifai_model_name):
+    if ("llama" in clarifai_model_name) or ("mistral" in clarifai_model_name):
         return "", "meta-llama/llama-2-chat"
     else:
         return "", clarifai_model_name
 
+
 async def async_completion(
-        model: str,
-        prompt: str,
-        api_base: str,
-        custom_prompt_dict: dict,
-        model_response: ModelResponse,
-        print_verbose: Callable,
-        encoding,
-        api_key,
-        logging_obj,
-        data=None,
-        optional_params=None,
-        litellm_params=None,
-        logger_fn=None,
-        headers={}):
-
-    async_handler = AsyncHTTPHandler(
-        timeout=httpx.Timeout(timeout=600.0, connect=5.0)
-    )
+    model: str,
+    prompt: str,
+    api_base: str,
+    custom_prompt_dict: dict,
+    model_response: ModelResponse,
+    print_verbose: Callable,
+    encoding,
+    api_key,
+    logging_obj,
+    data=None,
+    optional_params=None,
+    litellm_params=None,
+    logger_fn=None,
+    headers={},
+):
+
+    async_handler = AsyncHTTPHandler(timeout=httpx.Timeout(timeout=600.0, connect=5.0))
     response = await async_handler.post(
-        api_base, headers=headers, data=json.dumps(data)
-    )
-
-    return process_response(
-        model=model,
-        prompt=prompt,
-        response=response,
-        model_response=model_response,
-        api_key=api_key,
-        data=data,
-        encoding=encoding,
-        logging_obj=logging_obj,
+        api_base, headers=headers, data=json.dumps(data)
     )
 
+    return process_response(
+        model=model,
+        prompt=prompt,
+        response=response,
+        model_response=model_response,
+        api_key=api_key,
+        data=data,
+        encoding=encoding,
+        logging_obj=logging_obj,
+    )
+
+
 def completion(
     model: str,
     messages: list,
@@ -207,14 +203,12 @@ def completion(
 ):
     headers = validate_environment(api_key)
     model = convert_model_to_url(model, api_base)
-    prompt = " ".join(message["content"] for message in messages) # TODO
+    prompt = " ".join(message["content"] for message in messages)  # TODO
 
     ## Load Config
     config = litellm.ClarifaiConfig.get_config()
     for k, v in config.items():
-        if (
-            k not in optional_params
-        ):
+        if k not in optional_params:
             optional_params[k] = v
 
     custom_llm_provider, orig_model_name = get_prompt_model_name(model)
@@ -223,14 +217,14 @@ def completion(
             model=orig_model_name,
             messages=messages,
             api_key=api_key,
-            custom_llm_provider="clarifai"
+            custom_llm_provider="clarifai",
         )
     else:
         prompt = prompt_factory(
             model=orig_model_name,
             messages=messages,
             api_key=api_key,
-            custom_llm_provider=custom_llm_provider
+            custom_llm_provider=custom_llm_provider,
         )
 
     # print(prompt); exit(0)
@@ -240,7 +234,6 @@ def completion(
     }
 
     data = completions_to_model(data)
-
     ## LOGGING
     logging_obj.pre_call(
         input=prompt,
@@ -251,7 +244,7 @@ def completion(
             "api_base": api_base,
         },
     )
-    if acompletion==True:
+    if acompletion == True:
         return async_completion(
             model=model,
             prompt=prompt,
@@ -271,15 +264,17 @@ def completion(
     else:
         ## COMPLETION CALL
         response = requests.post(
-                model,
-                headers=headers,
-                data=json.dumps(data),
-            )
+            model,
+            headers=headers,
+            data=json.dumps(data),
+        )
 
     # print(response.content); exit()
     if response.status_code != 200:
-        raise ClarifaiError(status_code=response.status_code, message=response.text, url=model)
-
+        raise ClarifaiError(
+            status_code=response.status_code, message=response.text, url=model
+        )
+
     if "stream" in optional_params and optional_params["stream"] == True:
         completion_stream = response.iter_lines()
         stream_response = CustomStreamWrapper(
@@ -287,11 +282,11 @@ def completion(
             model=model,
             custom_llm_provider="clarifai",
             logging_obj=logging_obj,
-            )
+        )
         return stream_response
-
+
     else:
-        return process_response(
+        return process_response(
             model=model,
             prompt=prompt,
             response=response,
@@ -299,8 +294,9 @@ def process_response(
             api_key=api_key,
             data=data,
             encoding=encoding,
-            logging_obj=logging_obj)
-
+            logging_obj=logging_obj,
+        )
+
 
 class ModelResponseIterator:
     def __init__(self, model_response):
@@ -325,4 +321,4 @@ class ModelResponseIterator:
         if self.is_done:
             raise StopAsyncIteration
         self.is_done = True
-        return self.model_response
\ No newline at end of file
+        return self.model_response
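For reviewers: a minimal sketch of how the two code paths touched by this patch might be exercised end to end. The key is a placeholder, and `clarifai/mistral.completion.mistral-7B-Instruct` is just one example of the `user_id.app_id.model_id` format that `convert_model_to_url` splits on; substitute any Clarifai community model and a real `CLARIFAI_API_KEY`.

```python
# Sketch only: exercises the sync requests.post path and the new
# AsyncHTTPHandler-based async_completion path from this diff.
import asyncio
import os

import litellm

os.environ["CLARIFAI_API_KEY"] = "..."  # placeholder; set a real Clarifai key

messages = [{"role": "user", "content": "Explain gravity in one sentence."}]

# Sync path: litellm.completion() -> clarifai completion() -> requests.post()
response = litellm.completion(
    model="clarifai/mistral.completion.mistral-7B-Instruct",
    messages=messages,
)
print(response.choices[0].message.content)


# Async path: litellm.acompletion() drives the acompletion == True branch,
# so the provider returns the async_completion() coroutine instead.
async def main():
    response = await litellm.acompletion(
        model="clarifai/mistral.completion.mistral-7B-Instruct",
        messages=messages,
    )
    print(response.choices[0].message.content)


asyncio.run(main())
```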