litellm-mirror/litellm/llms/nlp_cloud/chat/handler.py
Krish Dholakia 350cfc36f7
Litellm merge pr (#7161)
* build: merge branch

* test: fix openai naming

* fix(main.py): fix openai renaming

* style: ignore function length for config factory

* fix(sagemaker/): fix routing logic

* fix: fix imports

* fix: fix override
2024-12-10 22:49:26 -08:00

140 lines
3.8 KiB
Python

import json
import os
import time
import types
from enum import Enum
from typing import Any, Callable, List, Optional, Union
import httpx
import litellm
from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
HTTPHandler,
_get_httpx_client,
get_async_httpx_client,
)
from litellm.types.llms.openai import AllMessageValues
from litellm.utils import ModelResponse, Usage
from ..common_utils import NLPCloudError
from .transformation import NLPCloudConfig
nlp_config = NLPCloudConfig()
def completion(
model: str,
messages: list,
api_base: str,
model_response: ModelResponse,
print_verbose: Callable,
encoding,
api_key,
logging_obj,
optional_params: dict,
litellm_params: dict,
logger_fn=None,
default_max_tokens_to_sample=None,
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
headers={},
):
headers = nlp_config.validate_environment(
api_key=api_key,
headers=headers,
model=model,
messages=messages,
optional_params=optional_params,
)
## Load Config
config = litellm.NLPCloudConfig.get_config()
for k, v in config.items():
if (
k not in optional_params
): # completion(top_k=3) > togetherai_config(top_k=3) <- allows for dynamic variables to be passed in
optional_params[k] = v
completion_url_fragment_1 = api_base
completion_url_fragment_2 = "/generation"
model = model
completion_url = completion_url_fragment_1 + model + completion_url_fragment_2
data = nlp_config.transform_request(
model=model,
messages=messages,
optional_params=optional_params,
litellm_params=litellm_params,
headers=headers,
)
## LOGGING
logging_obj.pre_call(
input=None,
api_key=api_key,
additional_args={
"complete_input_dict": data,
"headers": headers,
"api_base": completion_url,
},
)
## COMPLETION CALL
if client is None or not isinstance(client, HTTPHandler):
client = _get_httpx_client()
response = client.post(
completion_url,
headers=headers,
data=json.dumps(data),
stream=optional_params["stream"] if "stream" in optional_params else False,
)
if "stream" in optional_params and optional_params["stream"] is True:
return clean_and_iterate_chunks(response)
else:
return nlp_config.transform_response(
model=model,
raw_response=response,
model_response=model_response,
logging_obj=logging_obj,
api_key=api_key,
request_data=data,
messages=messages,
optional_params=optional_params,
litellm_params=litellm_params,
encoding=encoding,
)
# def clean_and_iterate_chunks(response):
# def process_chunk(chunk):
# print(f"received chunk: {chunk}")
# cleaned_chunk = chunk.decode("utf-8")
# # Perform further processing based on your needs
# return cleaned_chunk
# for line in response.iter_lines():
# if line:
# yield process_chunk(line)
def clean_and_iterate_chunks(response):
buffer = b""
for chunk in response.iter_content(chunk_size=1024):
if not chunk:
break
buffer += chunk
while b"\x00" in buffer:
buffer = buffer.replace(b"\x00", b"")
yield buffer.decode("utf-8")
buffer = b""
# No more data expected, yield any remaining data in the buffer
if buffer:
yield buffer.decode("utf-8")
def embedding():
# logic for parsing in - calling - parsing out model embedding calls
pass