* feat(base_llm): initial commit for common base config class. Addresses code QA critique https://github.com/andrewyng/aisuite/issues/113#issuecomment-2512369132
* feat(base_llm/): add transform request/response abstract methods to base config class
* feat(cohere-+-clarifai): refactor integrations to use common base config class
* fix: fix linting errors
* refactor(anthropic/): move anthropic + vertex anthropic to use base config
* test: fix xai test
* test: fix tests
* fix: fix linting errors
* test: comment out WIP test
* fix(transformation.py): fix "is pdf used" check
* fix: fix linting error
177 lines
4.5 KiB
Python
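"""Clarifai completion handler for litellm.

Routes Clarifai completion calls through litellm.ClarifaiConfig for
request/response transformation, with sync, async, and streaming paths.
"""
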
import json
from typing import Callable, List

import litellm
from litellm.llms.custom_httpx.http_handler import (
    _get_httpx_client,
    get_async_httpx_client,
)
from litellm.types.llms.openai import AllMessageValues
from litellm.utils import CustomStreamWrapper, ModelResponse

from ..common_utils import ClarifaiError


async def async_completion(
    model: str,
    messages: List[AllMessageValues],
    model_response: ModelResponse,
    encoding,
    api_key,
    api_base: str,
    logging_obj,
    data: dict,
    optional_params: dict,
    litellm_params=None,
    logger_fn=None,
    headers={},
):
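    """Async path for a Clarifai completion call.

    Posts the already-transformed request payload to ``api_base`` and maps
    the raw HTTP response onto ``model_response``.
    """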
    async_handler = get_async_httpx_client(
        llm_provider=litellm.LlmProviders.CLARIFAI,
        params={"timeout": 600.0},
    )
    response = await async_handler.post(
        url=api_base, headers=headers, data=json.dumps(data)
    )

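    # Map the raw HTTP response onto the litellm ModelResponse shape.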
    return litellm.ClarifaiConfig().transform_response(
        model=model,
        raw_response=response,
        model_response=model_response,
        logging_obj=logging_obj,
        api_key=api_key,
        request_data=data,
        messages=messages,
        optional_params=optional_params,
        encoding=encoding,
    )


def completion(
    model: str,
    messages: list,
    api_base: str,
    model_response: ModelResponse,
    print_verbose: Callable,
    encoding,
    api_key,
    logging_obj,
    optional_params: dict,
    litellm_params: dict,
    custom_prompt_dict={},
    acompletion=False,
    logger_fn=None,
    headers={},
):
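    """Handle a Clarifai completion call.

    Validates the environment, builds the request via ClarifaiConfig, logs
    the outgoing payload, then dispatches to the async, streaming, or plain
    sync execution path.
    """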
    headers = litellm.ClarifaiConfig().validate_environment(
        api_key=api_key,
        headers=headers,
        model=model,
        messages=messages,
        optional_params=optional_params,
    )
    data = litellm.ClarifaiConfig().transform_request(
        model=model,
        messages=messages,
        optional_params=optional_params,
        litellm_params=litellm_params,
        headers=headers,
    )

    ## LOGGING
    logging_obj.pre_call(
        input=data,
        api_key=api_key,
        additional_args={
            "complete_input_dict": data,
            "headers": headers,
            "api_base": api_base,
        },
    )
    if acompletion is True:
        return async_completion(
            model=model,
            messages=messages,
            api_base=api_base,
            model_response=model_response,
            encoding=encoding,
            api_key=api_key,
            logging_obj=logging_obj,
            data=data,
            optional_params=optional_params,
            litellm_params=litellm_params,
            logger_fn=logger_fn,
            headers=headers,
        )
    else:
        ## COMPLETION CALL
        httpx_client = _get_httpx_client(
            params={"timeout": 600.0},
        )
        response = httpx_client.post(
            url=api_base,
            headers=headers,
            data=json.dumps(data),
        )

        if response.status_code != 200:
            raise ClarifaiError(status_code=response.status_code, message=response.text)

if "stream" in optional_params and optional_params["stream"] is True:
|
|
completion_stream = response.iter_lines()
|
|
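            # Wrap the line iterator so callers consume the response through
            # litellm's standard streaming interface.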
            stream_response = CustomStreamWrapper(
                completion_stream=completion_stream,
                model=model,
                custom_llm_provider="clarifai",
                logging_obj=logging_obj,
            )
            return stream_response

        else:
            return litellm.ClarifaiConfig().transform_response(
                model=model,
                raw_response=response,
                model_response=model_response,
                logging_obj=logging_obj,
                api_key=api_key,
                request_data=data,
                messages=messages,
                optional_params=optional_params,
                encoding=encoding,
            )


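# Usage sketch (illustrative, not part of the module): this handler is not
# called directly; it is reached through litellm's public entry point. The
# Clarifai model id below is a hypothetical placeholder.
#
#   import litellm
#
#   response = litellm.completion(
#       model="clarifai/mistralai.completion.mistral-7B-Instruct",
#       messages=[{"role": "user", "content": "Hello"}],
#   )
#   print(response.choices[0].message.content)

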
class ModelResponseIterator:
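    """One-shot iterator that yields a single ModelResponse, then stops.

    Lets a non-streaming response be consumed through the same sync/async
    iteration interface as a streamed one.
    """
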
    def __init__(self, model_response):
        self.model_response = model_response
        self.is_done = False

    # Sync iterator
    def __iter__(self):
        return self

    def __next__(self):
        if self.is_done:
            raise StopIteration
        self.is_done = True
        return self.model_response

    # Async iterator
    def __aiter__(self):
        return self

    async def __anext__(self):
        if self.is_done:
            raise StopAsyncIteration
        self.is_done = True
        return self.model_response