litellm-mirror/litellm/llms/clarifai/chat/handler.py
Krish Dholakia 501885d653 Litellm code qa common config (#7113)
* feat(base_llm): initial commit for common base config class

Addresses code qa critique https://github.com/andrewyng/aisuite/issues/113#issuecomment-2512369132

* feat(base_llm/): add transform request/response abstract methods to base config class

* feat(cohere-+-clarifai): refactor integrations to use common base config class

* fix: fix linting errors

* refactor(anthropic/): move anthropic + vertex anthropic to use base config

* test: fix xai test

* test: fix tests

* fix: fix linting errors

* test: comment out WIP test

* fix(transformation.py): fix is pdf used check

* fix: fix linting error
2024-12-09 15:58:25 -08:00

177 lines
4.5 KiB
Python

import json
import os
import time
import traceback
import types
from typing import Callable, List, Optional
import httpx
import requests
import litellm
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
_get_httpx_client,
get_async_httpx_client,
)
from litellm.types.llms.openai import AllMessageValues
from litellm.utils import Choices, CustomStreamWrapper, Message, ModelResponse, Usage
from ...prompt_templates.factory import custom_prompt, prompt_factory
from ..common_utils import ClarifaiError
async def async_completion(
    model: str,
    messages: List[AllMessageValues],
    model_response: ModelResponse,
    encoding,
    api_key,
    api_base: str,
    logging_obj,
    data: dict,
    optional_params: dict,
    litellm_params=None,
    logger_fn=None,
    headers: Optional[dict] = None,
):
    """POST a Clarifai chat request asynchronously and transform the reply.

    Args:
        model: Model identifier, forwarded to ``ClarifaiConfig.transform_response``.
        messages: Chat messages of the request, forwarded to ``transform_response``.
        model_response: Response object handed to ``transform_response`` to populate.
        encoding: Encoding object forwarded to ``transform_response``.
        api_key: API key forwarded to ``transform_response``.
        api_base: URL the JSON-serialised ``data`` is POSTed to.
        logging_obj: litellm logging object forwarded to ``transform_response``.
        data: Request payload (already built by the caller); sent as JSON.
        optional_params: Provider params forwarded to ``transform_response``.
        litellm_params: Accepted but not used by this function.
        logger_fn: Accepted but not used by this function.
        headers: HTTP headers for the POST; ``None`` is treated as no headers.

    Returns:
        The result of ``ClarifaiConfig().transform_response`` for the HTTP response.

    Raises:
        ClarifaiError: if the HTTP status code is not 200 (same check as the
            synchronous path in ``completion``).
    """
    # Use None instead of a shared mutable ``{}`` default.
    if headers is None:
        headers = {}

    async_handler = get_async_httpx_client(
        llm_provider=litellm.LlmProviders.CLARIFAI,
        params={"timeout": 600.0},
    )
    response = await async_handler.post(
        url=api_base, headers=headers, data=json.dumps(data)
    )

    # Keep error reporting consistent with the sync branch of ``completion``,
    # which raises ClarifaiError on any non-200 status.
    if response.status_code != 200:
        raise ClarifaiError(status_code=response.status_code, message=response.text)

    return litellm.ClarifaiConfig().transform_response(
        model=model,
        raw_response=response,
        model_response=model_response,
        logging_obj=logging_obj,
        api_key=api_key,
        request_data=data,
        messages=messages,
        optional_params=optional_params,
        encoding=encoding,
    )
def completion(
    model: str,
    messages: list,
    api_base: str,
    model_response: ModelResponse,
    print_verbose: Callable,
    encoding,
    api_key,
    logging_obj,
    optional_params: dict,
    litellm_params: dict,
    custom_prompt_dict: Optional[dict] = None,
    acompletion=False,
    logger_fn=None,
    headers: Optional[dict] = None,
):
    """Run a Clarifai chat completion, synchronously or via ``async_completion``.

    Builds headers and the request payload through ``ClarifaiConfig``, logs the
    outgoing request, then either delegates to ``async_completion`` or performs
    a blocking POST.

    Args:
        model: Model identifier passed to ``ClarifaiConfig``.
        messages: Chat messages to send.
        api_base: URL the request is POSTed to.
        model_response: Response object that ``transform_response`` populates.
        print_verbose: Accepted but not used by this function.
        encoding: Encoding object forwarded to ``transform_response``.
        api_key: API key passed to ``validate_environment`` and logging.
        logging_obj: litellm logging object (``pre_call`` is invoked here).
        optional_params: Provider params; ``optional_params["stream"] is True``
            selects the streaming return path on the sync branch.
        litellm_params: Forwarded to ``transform_request`` / ``async_completion``.
        custom_prompt_dict: Accepted but not used by this function.
        acompletion: If ``True``, return the (un-awaited) ``async_completion``
            coroutine instead of making a blocking request.
        logger_fn: Forwarded to ``async_completion``.
        headers: Initial headers handed to ``validate_environment``; ``None``
            is treated as an empty dict.

    Returns:
        A coroutine (when ``acompletion`` is True), a ``CustomStreamWrapper``
        (sync + ``stream=True``), or the result of ``transform_response``.

    Raises:
        ClarifaiError: on the sync branch, when the HTTP status is not 200.
    """
    # Avoid shared mutable defaults: validate_environment receives a fresh dict.
    if headers is None:
        headers = {}
    if custom_prompt_dict is None:
        custom_prompt_dict = {}

    config = litellm.ClarifaiConfig()
    headers = config.validate_environment(
        api_key=api_key,
        headers=headers,
        model=model,
        messages=messages,
        optional_params=optional_params,
    )
    data = config.transform_request(
        model=model,
        messages=messages,
        optional_params=optional_params,
        litellm_params=litellm_params,
        headers=headers,
    )

    ## LOGGING
    logging_obj.pre_call(
        input=data,
        api_key=api_key,
        additional_args={
            "complete_input_dict": data,
            "headers": headers,
            # Previously logged ``model`` here by mistake; log the real endpoint.
            "api_base": api_base,
        },
    )

    if acompletion is True:
        # NOTE(review): this branch does not wrap streaming responses; any
        # stream handling for async calls presumably happens in the caller.
        return async_completion(
            model=model,
            messages=messages,
            api_base=api_base,
            model_response=model_response,
            encoding=encoding,
            api_key=api_key,
            logging_obj=logging_obj,
            data=data,
            optional_params=optional_params,
            litellm_params=litellm_params,
            logger_fn=logger_fn,
            headers=headers,
        )

    ## COMPLETION CALL
    httpx_client = _get_httpx_client(
        params={"timeout": 600.0},
    )
    response = httpx_client.post(
        url=api_base,
        headers=headers,
        data=json.dumps(data),
    )
    if response.status_code != 200:
        raise ClarifaiError(status_code=response.status_code, message=response.text)

    if optional_params.get("stream") is True:
        # NOTE(review): the POST is not issued with a streaming flag, so the
        # body may already be fully read before iter_lines() runs — confirm.
        return CustomStreamWrapper(
            completion_stream=response.iter_lines(),
            model=model,
            custom_llm_provider="clarifai",
            logging_obj=logging_obj,
        )

    return config.transform_response(
        model=model,
        raw_response=response,
        model_response=model_response,
        logging_obj=logging_obj,
        api_key=api_key,
        request_data=data,
        messages=messages,
        optional_params=optional_params,
        encoding=encoding,
    )
class ModelResponseIterator:
    """Yield one pre-built model response exactly once.

    Works with both ``for`` and ``async for``. The sync and async protocols
    share one ``is_done`` flag, so consuming the response through either
    interface exhausts the iterator for both.
    """

    def __init__(self, model_response):
        self.model_response = model_response
        self.is_done = False

    def _consume(self, exhausted_exc):
        """Hand out the stored response once; afterwards raise ``exhausted_exc``."""
        if self.is_done:
            raise exhausted_exc
        self.is_done = True
        return self.model_response

    # --- synchronous iteration protocol ---
    def __iter__(self):
        return self

    def __next__(self):
        return self._consume(StopIteration)

    # --- asynchronous iteration protocol ---
    def __aiter__(self):
        return self

    async def __anext__(self):
        return self._consume(StopAsyncIteration)