Merge branch 'main' into main

Authored by Krish Dholakia on 2023-12-18 17:54:34 -08:00; committed by GitHub
commit 408f232bd7
11 changed files with 185 additions and 70 deletions

View file

@@ -15,7 +15,7 @@ join our [discord](https://discord.gg/wuPM9dRgDw)
## Pre-Requisites
Ensure you have run `pip install langfuse` for this integration
```shell
pip install langfuse litellm
pip install langfuse==1.14.0 litellm
```
## Quick Start
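For reference, a minimal Python sketch of the quick start this section introduces, assuming the standard Langfuse environment variables and an OpenAI key (all values below are placeholders):

```python
import os
import litellm
from litellm import completion

# placeholders -- set these to your real keys
os.environ["LANGFUSE_PUBLIC_KEY"] = ""
os.environ["LANGFUSE_SECRET_KEY"] = ""
os.environ["OPENAI_API_KEY"] = ""

# log every successful LLM call to Langfuse
litellm.success_callback = ["langfuse"]

response = completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hello from litellm"}],
)
print(response)
```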

View file

@@ -14,7 +14,7 @@ import os
os.environ['MISTRAL_API_KEY'] = ""
response = completion(
model="mistral/mistral-tiny"",
model="mistral/mistral-tiny",
messages=[
{"role": "user", "content": "hello from litellm"}
],

View file

@@ -461,7 +461,7 @@ We will use the `--config` to set `litellm.success_callback = ["langfuse"]` this
**Step 1** Install langfuse
```shell
pip install langfuse
pip install langfuse==1.14.0
```
**Step 2**: Create a `config.yaml` file and set `litellm_settings`: `success_callback`

View file

@@ -631,24 +631,27 @@ class OpenAITextCompletion(BaseLLM):
api_key: str,
model: str):
async with httpx.AsyncClient() as client:
response = await client.post(api_base, json=data, headers=headers, timeout=litellm.request_timeout)
response_json = response.json()
if response.status_code != 200:
raise OpenAIError(status_code=response.status_code, message=response.text)
try:
response = await client.post(api_base, json=data, headers=headers, timeout=litellm.request_timeout)
response_json = response.json()
if response.status_code != 200:
raise OpenAIError(status_code=response.status_code, message=response.text)
## LOGGING
logging_obj.post_call(
input=prompt,
api_key=api_key,
original_response=response,
additional_args={
"headers": headers,
"api_base": api_base,
},
)
## LOGGING
logging_obj.post_call(
input=prompt,
api_key=api_key,
original_response=response,
additional_args={
"headers": headers,
"api_base": api_base,
},
)
## RESPONSE OBJECT
return self.convert_to_model_response_object(response_object=response_json, model_response_object=model_response)
## RESPONSE OBJECT
return self.convert_to_model_response_object(response_object=response_json, model_response_object=model_response)
except Exception as e:
raise e
def streaming(self,
logging_obj,
@@ -687,9 +690,12 @@ class OpenAITextCompletion(BaseLLM):
method="POST",
timeout=litellm.request_timeout
) as response:
if response.status_code != 200:
raise OpenAIError(status_code=response.status_code, message=response.text)
try:
if response.status_code != 200:
raise OpenAIError(status_code=response.status_code, message=response.text)
streamwrapper = CustomStreamWrapper(completion_stream=response.aiter_lines(), model=model, custom_llm_provider="text-completion-openai",logging_obj=logging_obj)
async for transformed_chunk in streamwrapper:
yield transformed_chunk
streamwrapper = CustomStreamWrapper(completion_stream=response.aiter_lines(), model=model, custom_llm_provider="text-completion-openai",logging_obj=logging_obj)
async for transformed_chunk in streamwrapper:
yield transformed_chunk
except Exception as e:
raise e

View file

@@ -0,0 +1,30 @@
from typing import List, Dict
import types
class OpenrouterConfig():
"""
Reference: https://openrouter.ai/docs#format
"""
# OpenRouter-only parameters
extra_body: Dict[str, List[str]] = {
'transforms': [] # default transforms to []
}
def __init__(self,
transforms: List[str] = [],
models: List[str] = [],
route: str = '',
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != 'self' and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {k: v for k, v in cls.__dict__.items()
if not k.startswith('__')
and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod))
and v is not None}
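For reference, a rough usage sketch of this new config class, assuming the module is importable as `litellm.llms.openrouter` (consistent with the import added to `main.py` below); the transform value is illustrative:

```python
from litellm.llms.openrouter import OpenrouterConfig

# instantiating stores every non-None argument as a *class-level* attribute,
# so later OpenRouter calls pick the values up via get_config()
OpenrouterConfig(transforms=["middle-out"])  # illustrative transform value

print(OpenrouterConfig.get_config())
# returns the non-None, non-dunder class attributes, e.g.
# {'extra_body': {'transforms': []}, 'transforms': ['middle-out'], 'models': [], 'route': ''}
```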

View file

@@ -52,6 +52,7 @@ from .llms import (
cohere,
petals,
oobabooga,
openrouter,
palm,
vertex_ai,
maritalk)
@@ -260,8 +261,8 @@ def completion(
model: str,
# Optional OpenAI params: see https://platform.openai.com/docs/api-reference/chat/create
messages: List = [],
functions: List = [],
function_call: str = "", # optional params
functions: Optional[List] = None,
function_call: Optional[str] = None,
timeout: Optional[Union[float, int]] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
@@ -1026,14 +1027,23 @@ def completion(
}
)
## Load Config
config = openrouter.OpenrouterConfig.get_config()
for k, v in config.items():
if k == "extra_body":
# we use openai 'extra_body' to pass openrouter specific params - transforms, route, models
if "extra_body" in optional_params:
optional_params[k].update(v)
else:
optional_params[k] = v
elif k not in optional_params:
optional_params[k] = v
data = {
"model": model,
"messages": messages,
**optional_params
}
## LOGGING
logging.pre_call(input=messages, api_key=openai.api_key, additional_args={"complete_input_dict": data, "headers": headers})
## COMPLETION CALL
## COMPLETION CALL
response = openai_chat_completions.completion(
@@ -1510,8 +1520,8 @@ def batch_completion(
model: str,
# Optional OpenAI params: see https://platform.openai.com/docs/api-reference/chat/create
messages: List = [],
functions: List = [],
function_call: str = "", # optional params
functions: Optional[List] = None,
function_call: Optional[str] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
n: Optional[int] = None,
@@ -2193,10 +2203,8 @@ def text_completion(
if stream == True or kwargs.get("stream", False) == True:
response = TextCompletionStreamWrapper(completion_stream=response, model=model)
return response
if asyncio.iscoroutine(response):
response = asyncio.run(response)
if kwargs.get("acompletion", False) == True:
return response
transformed_logprobs = None
# only supported for TGI models
try:

View file

@@ -47,7 +47,7 @@ litellm_settings:
# setting callback class
# callbacks: custom_callbacks.proxy_handler_instance # sets litellm.callbacks = [proxy_handler_instance]
general_settings:
# general_settings:
environment_variables:
# otel: True # OpenTelemetry Logger

View file

@@ -84,11 +84,11 @@ class Router:
self.set_verbose = set_verbose
self.deployment_names: List = [] # names of models under litellm_params. ex. azure/chatgpt-v-2
self.deployment_latency_map = {}
if model_list:
model_list = copy.deepcopy(model_list)
self.set_model_list(model_list)
self.healthy_deployments: List = self.model_list
self.deployment_latency_map = {}
for m in model_list:
self.deployment_latency_map[m["litellm_params"]["model"]] = 0
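For context, a minimal sketch of the `model_list` shape being deep-copied here; the deployment values are placeholders (the `azure/chatgpt-v-2` name mirrors the comment above):

```python
from litellm import Router

model_list = [
    {
        "model_name": "gpt-3.5-turbo",        # alias callers route to
        "litellm_params": {                   # per-deployment call params
            "model": "azure/chatgpt-v-2",     # placeholder deployment
            "api_key": "placeholder-key",
            "api_base": "https://example-endpoint.openai.azure.com",
            "api_version": "2023-07-01-preview",
        },
    },
]

# with the new copy.deepcopy, the Router works on its own copy and
# leaves the caller's model_list unmodified
router = Router(model_list=model_list)
```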

View file

@@ -169,17 +169,37 @@ def test_text_completion_stream():
# test_text_completion_stream()
async def test_text_completion_async_stream():
try:
response = await atext_completion(
model="text-completion-openai/text-davinci-003",
prompt="good morning",
stream=True,
max_tokens=10,
)
async for chunk in response:
print(f"chunk: {chunk}")
except Exception as e:
pytest.fail(f"GOT exception for HF In streaming{e}")
# async def test_text_completion_async_stream():
# try:
# response = await atext_completion(
# model="text-completion-openai/text-davinci-003",
# prompt="good morning",
# stream=True,
# max_tokens=10,
# )
# async for chunk in response:
# print(f"chunk: {chunk}")
# except Exception as e:
# pytest.fail(f"GOT exception for HF In streaming{e}")
asyncio.run(test_text_completion_async_stream())
# asyncio.run(test_text_completion_async_stream())
def test_async_text_completion():
litellm.set_verbose = True
print('test_async_text_completion')
async def test_get_response():
try:
response = await litellm.atext_completion(
model="gpt-3.5-turbo-instruct",
prompt="good morning",
stream=False,
max_tokens=10
)
print(f"response: {response}")
except litellm.Timeout as e:
print(e)
except Exception as e:
print(e)
asyncio.run(test_get_response())
test_async_text_completion()

View file

@@ -2336,8 +2336,8 @@ def get_optional_params_embeddings(
def get_optional_params( # use the openai defaults
# 12 optional params
functions=[],
function_call="",
functions=None,
function_call=None,
temperature=None,
top_p=None,
n=None,
@@ -2363,8 +2363,8 @@ def get_optional_params( # use the openai defaults
for k, v in special_params.items():
passed_params[k] = v
default_params = {
"functions":[],
"function_call":"",
"functions": None,
"function_call": None,
"temperature":None,
"top_p":None,
"n":None,
@@ -2851,6 +2851,57 @@ def get_optional_params( # use the openai defaults
if random_seed is not None:
extra_body["random_seed"] = random_seed
optional_params["extra_body"] = extra_body # openai client supports `extra_body` param
elif custom_llm_provider == "openrouter":
supported_params = ["functions", "function_call", "temperature", "top_p", "n", "stream", "stop", "max_tokens", "presence_penalty", "frequency_penalty", "logit_bias", "user", "response_format", "seed", "tools", "tool_choice", "max_retries"]
_check_valid_arg(supported_params=supported_params)
if functions is not None:
optional_params["functions"] = functions
if function_call is not None:
optional_params["function_call"] = function_call
if temperature is not None:
optional_params["temperature"] = temperature
if top_p is not None:
optional_params["top_p"] = top_p
if n is not None:
optional_params["n"] = n
if stream is not None:
optional_params["stream"] = stream
if stop is not None:
optional_params["stop"] = stop
if max_tokens is not None:
optional_params["max_tokens"] = max_tokens
if presence_penalty is not None:
optional_params["presence_penalty"] = presence_penalty
if frequency_penalty is not None:
optional_params["frequency_penalty"] = frequency_penalty
if logit_bias is not None:
optional_params["logit_bias"] = logit_bias
if user is not None:
optional_params["user"] = user
if response_format is not None:
optional_params["response_format"] = response_format
if seed is not None:
optional_params["seed"] = seed
if tools is not None:
optional_params["tools"] = tools
if tool_choice is not None:
optional_params["tool_choice"] = tool_choice
if max_retries is not None:
optional_params["max_retries"] = max_retries
# OpenRouter-only parameters
extra_body = {}
transforms = passed_params.pop("transforms", None)
models = passed_params.pop("models", None)
route = passed_params.pop("route", None)
if transforms is not None:
extra_body["transforms"] = transforms
if models is not None:
extra_body["models"] = models
if route is not None:
extra_body["route"] = route
optional_params["extra_body"] = extra_body # openai client supports `extra_body` param
else: # assume passing in params for openai/azure openai
supported_params = ["functions", "function_call", "temperature", "top_p", "n", "stream", "stop", "max_tokens", "presence_penalty", "frequency_penalty", "logit_bias", "user", "response_format", "seed", "tools", "tool_choice", "max_retries"]
_check_valid_arg(supported_params=supported_params)
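A sketch of how the OpenRouter-only parameters handled in the branch above would be supplied from a `completion()` call; the model name and values are illustrative. They end up under `extra_body` rather than as top-level OpenAI params:

```python
import os
import litellm

os.environ["OPENROUTER_API_KEY"] = ""  # placeholder

response = litellm.completion(
    model="openrouter/mistralai/mistral-7b-instruct",  # illustrative OpenRouter model
    messages=[{"role": "user", "content": "hello from litellm"}],
    transforms=["middle-out"],  # OpenRouter-only prompt transform
    route="fallback",           # OpenRouter-only routing hint
)
# get_optional_params() pops transforms / models / route from the passed params
# and places them in optional_params["extra_body"], which the OpenAI client
# forwards to OpenRouter unchanged.
```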
@@ -3962,7 +4013,7 @@ def convert_to_model_response_object(response_object: Optional[dict]=None, model
raise Exception("Error in response object format")
if model_response_object is None:
model_response_object = EmbeddingResponse()
model_response_object = ImageResponse()
if "created" in response_object:
model_response_object.created = response_object["created"]

View file

@@ -3,24 +3,24 @@ anyio==4.2.0 # openai + http req.
openai>=1.0.0 # openai req.
fastapi # server dep
pydantic>=2.5 # openai req.
appdirs # server dep
backoff # server dep
pyyaml # server dep
uvicorn # server dep
boto3 # aws bedrock/sagemaker calls
redis # caching
prisma # for db
mangum # for aws lambda functions
google-generativeai # for vertex ai calls
appdirs==1.4.4 # server dep
backoff==2.2.1 # server dep
pyyaml==6.0 # server dep
uvicorn==0.22.0 # server dep
boto3==1.28.58 # aws bedrock/sagemaker calls
redis==4.6.0 # caching
prisma==0.11.0 # for db
mangum==0.17.0 # for aws lambda functions
google-generativeai==0.1.0 # for vertex ai calls
traceloop-sdk==0.5.3 # for open telemetry logging
langfuse==1.14.0 # for langfuse self-hosted logging
### LITELLM PACKAGE DEPENDENCIES
python-dotenv>=0.2.0 # for env
tiktoken>=0.4.0 # for calculating usage
importlib-metadata>=6.8.0 # for random utils
tokenizers # for calculating usage
click # for proxy cli
tokenizers==0.14.0 # for calculating usage
click==8.1.7 # for proxy cli
jinja2==3.1.2 # for prompt templates
certifi>=2023.7.22 # [TODO] clean up
aiohttp # for network calls
aiohttp==3.8.4 # for network calls
####