forked from phoenix/litellm-mirror

Commit 47ba8082df: Merge branch 'main' into public-fix-1

109 changed files with 8257 additions and 3200 deletions
@@ -79,6 +79,11 @@ jobs:
      steps:
        - checkout
+        - run:
+            name: Copy model_prices_and_context_window File to model_prices_and_context_window_backup
+            command: |
+              cp model_prices_and_context_window.json litellm/model_prices_and_context_window_backup.json
+
        - run:
            name: Check if litellm dir was updated or if pyproject.toml was modified
            command: |
.gitignore (6 changes, vendored)
@@ -19,3 +19,9 @@ litellm/proxy/_secret_config.yaml
litellm/tests/aiologs.log
litellm/tests/exception_data.txt
litellm/tests/config_*.yaml
+litellm/tests/langfuse.log
+litellm/tests/test_custom_logger.py
+litellm/tests/langfuse.log
+litellm/tests/dynamo*.log
+.vscode/settings.json
+litellm/proxy/log.txt
Dockerfile (31 changes)
@@ -1,8 +1,11 @@
# Base image
-ARG LITELLM_BASE_IMAGE=python:3.9-slim
+ARG LITELLM_BUILD_IMAGE=python:3.9

-# allow users to specify, else use python 3.9-slim
-FROM $LITELLM_BASE_IMAGE
+# Runtime image
+ARG LITELLM_RUNTIME_IMAGE=python:3.9-slim
+
+# allow users to specify, else use python 3.9
+FROM $LITELLM_BUILD_IMAGE as builder

# Set the working directory to /app
WORKDIR /app
@@ -16,7 +19,7 @@ RUN pip install --upgrade pip && \
    pip install build

# Copy the current directory contents into the container at /app
-COPY . /app
+COPY requirements.txt .

# Build the package
RUN rm -rf dist/* && python -m build
@@ -25,13 +28,27 @@ RUN rm -rf dist/* && python -m build
RUN pip install dist/*.whl

# Install any needed packages specified in requirements.txt
-RUN pip wheel --no-cache-dir --wheel-dir=wheels -r requirements.txt
-RUN pip install --no-cache-dir --find-links=wheels -r requirements.txt
+RUN pip install wheel && \
+    pip wheel --no-cache-dir --wheel-dir=/app/wheels -r requirements.txt
+
+###############################################################################
+FROM $LITELLM_RUNTIME_IMAGE as runtime
+
+WORKDIR /app
+
+# Copy the current directory contents into the container at /app
+COPY . .
+
+COPY --from=builder /app/wheels /app/wheels
+
+RUN pip install --no-index --find-links=/app/wheels -r requirements.txt
+
+# Trigger the Prisma CLI to be installed
+RUN prisma -v

EXPOSE 4000/tcp

# Start the litellm proxy, using the `litellm` cli command https://docs.litellm.ai/docs/simple_proxy

# Start the litellm proxy with default options
CMD ["--port", "4000"]
README.md (17 changes)
@@ -62,6 +62,22 @@ response = completion(model="command-nightly", messages=messages)
print(response)
```

+## Async ([Docs](https://docs.litellm.ai/docs/completion/stream#async-completion))
+
+```python
+from litellm import acompletion
+import asyncio
+
+async def test_get_response():
+    user_message = "Hello, how are you?"
+    messages = [{"content": user_message, "role": "user"}]
+    response = await acompletion(model="gpt-3.5-turbo", messages=messages)
+    return response
+
+response = asyncio.run(test_get_response())
+print(response)
+```
+
## Streaming ([Docs](https://docs.litellm.ai/docs/completion/stream))
liteLLM supports streaming the model response back, pass `stream=True` to get a streaming iterator in response.
Streaming is supported for all models (Bedrock, Huggingface, TogetherAI, Azure, OpenAI, etc.)
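As a reference for the streaming usage described above, a minimal sketch (model and prompt are illustrative; chunk access follows the OpenAI-style shape litellm mirrors):

```python
from litellm import completion

# stream=True returns an iterator of chunks instead of a single response
response = completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hello, how are you?"}],
    stream=True,
)
for chunk in response:
    # each chunk carries a delta with the next piece of the reply
    print(chunk.choices[0].delta.content or "", end="")
```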
@@ -140,6 +156,7 @@ response = completion(model="gpt-3.5-turbo", messages=[{"role": "user", "content
| [openrouter](https://docs.litellm.ai/docs/providers/openrouter) | ✅ | ✅ | ✅ | ✅ |
| [google - vertex_ai](https://docs.litellm.ai/docs/providers/vertex) | ✅ | ✅ | ✅ | ✅ |
| [google - palm](https://docs.litellm.ai/docs/providers/palm) | ✅ | ✅ | ✅ | ✅ |
+| [mistral ai api](https://docs.litellm.ai/docs/providers/mistral) | ✅ | ✅ | ✅ | ✅ |
| [ai21](https://docs.litellm.ai/docs/providers/ai21) | ✅ | ✅ | ✅ | ✅ |
| [baseten](https://docs.litellm.ai/docs/providers/baseten) | ✅ | ✅ | ✅ | ✅ |
| [vllm](https://docs.litellm.ai/docs/providers/vllm) | ✅ | ✅ | ✅ | ✅ |
BIN dist/litellm-1.12.5.dev1-py3-none-any.whl (vendored, new file, binary file not shown)
BIN dist/litellm-1.12.5.dev1.tar.gz (vendored, new file, binary file not shown)
BIN dist/litellm-1.12.6.dev1-py3-none-any.whl (vendored, new file, binary file not shown)
BIN dist/litellm-1.12.6.dev1.tar.gz (vendored, new file, binary file not shown)
BIN dist/litellm-1.12.6.dev2-py3-none-any.whl (vendored, new file, binary file not shown)
BIN dist/litellm-1.12.6.dev2.tar.gz (vendored, new file, binary file not shown)
BIN dist/litellm-1.12.6.dev3-py3-none-any.whl (vendored, new file, binary file not shown)
BIN dist/litellm-1.12.6.dev3.tar.gz (vendored, new file, binary file not shown)
BIN dist/litellm-1.12.6.dev4-py3-none-any.whl (vendored, new file, binary file not shown)
BIN dist/litellm-1.12.6.dev4.tar.gz (vendored, new file, binary file not shown)
BIN dist/litellm-1.12.6.dev5-py3-none-any.whl (vendored, new file, binary file not shown)
BIN dist/litellm-1.12.6.dev5.tar.gz (vendored, new file, binary file not shown)
BIN dist/litellm-1.14.0.dev1-py3-none-any.whl (vendored, new file, binary file not shown)
BIN dist/litellm-1.14.0.dev1.tar.gz (vendored, new file, binary file not shown)
BIN dist/litellm-1.14.5.dev1-py3-none-any.whl (vendored, new file, binary file not shown)
BIN dist/litellm-1.14.5.dev1.tar.gz (vendored, new file, binary file not shown)
@@ -55,27 +55,76 @@ litellm.cache = cache # set litellm.cache to your cache

```

-### Detecting Cached Responses
-For resposes that were returned as cache hit, the response includes a param `cache` = True
-
-:::info
-
-Only valid for OpenAI <= 0.28.1 [Let us know if you still need this](https://github.com/BerriAI/litellm/issues/new?assignees=&labels=bug&projects=&template=bug_report.yml&title=%5BBug%5D%3A+)
-:::
-
-Example response with cache hit
-```python
-{
-    'cache': True,
-    'id': 'chatcmpl-7wggdzd6OXhgE2YhcLJHJNZsEWzZ2',
-    'created': 1694221467,
-    'model': 'gpt-3.5-turbo-0613',
-    'choices': [
-        {
-            'index': 0, 'message': {'role': 'assistant', 'content': 'I\'m sorry, but I couldn\'t find any information about "litellm" or how many stars it has. It is possible that you may be referring to a specific product, service, or platform that I am not familiar with. Can you please provide more context or clarify your question?'
-        }, 'finish_reason': 'stop'}
-    ],
-    'usage': {'prompt_tokens': 17, 'completion_tokens': 59, 'total_tokens': 76},
-}
-```
+## Cache Initialization Parameters
+
+#### `type` (str, optional)
+
+The type of cache to initialize. It can be either "local" or "redis". Defaults to "local".
+
+#### `host` (str, optional)
+
+The host address for the Redis cache. This parameter is required if the `type` is set to "redis".
+
+#### `port` (int, optional)
+
+The port number for the Redis cache. This parameter is required if the `type` is set to "redis".
+
+#### `password` (str, optional)
+
+The password for the Redis cache. This parameter is required if the `type` is set to "redis".
+
+#### `supported_call_types` (list, optional)
+
+A list of call types to cache for. Defaults to caching for all call types. The available call types are:
+
+- "completion"
+- "acompletion"
+- "embedding"
+- "aembedding"
+
+#### `**kwargs` (additional keyword arguments)
+
+Additional keyword arguments are accepted for the initialization of the Redis cache using the `redis.Redis()` constructor. These arguments allow you to fine-tune the Redis cache configuration based on your specific needs.
+
+## Logging
+
+Cache hits are logged in success events as `kwarg["cache_hit"]`.
+
+Here's an example of accessing it:
+
+```python
+import asyncio
+import os
+import time
+
+import litellm
+from litellm.integrations.custom_logger import CustomLogger
+from litellm import completion, acompletion, Cache
+
+# create custom callback for success_events
+class MyCustomHandler(CustomLogger):
+    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
+        print(f"On Success")
+        print(f"Value of Cache hit: {kwargs['cache_hit']}")
+
+async def test_async_completion_azure_caching():
+    # set custom callback
+    customHandler_caching = MyCustomHandler()
+    litellm.callbacks = [customHandler_caching]
+
+    # init cache
+    litellm.cache = Cache(type="redis", host=os.environ['REDIS_HOST'], port=os.environ['REDIS_PORT'], password=os.environ['REDIS_PASSWORD'])
+    unique_time = time.time()
+    response1 = await litellm.acompletion(model="azure/chatgpt-v-2",
+                                          messages=[{
+                                              "role": "user",
+                                              "content": f"Hi 👋 - i'm async azure {unique_time}"
+                                          }],
+                                          caching=True)
+    await asyncio.sleep(1)
+    print(f"customHandler_caching.states pre-cache hit: {customHandler_caching.states}")
+    response2 = await litellm.acompletion(model="azure/chatgpt-v-2",
+                                          messages=[{
+                                              "role": "user",
+                                              "content": f"Hi 👋 - i'm async azure {unique_time}"
+                                          }],
+                                          caching=True)
+    await asyncio.sleep(1)  # success callbacks are done in parallel
+```
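Pulling the initialization parameters documented above together, a minimal sketch of enabling a Redis-backed cache (connection values are placeholders read from the environment):

```python
import os

import litellm
from litellm import Cache

# cache only completion-style calls; any extra kwargs would be forwarded to redis.Redis()
litellm.cache = Cache(
    type="redis",
    host=os.environ["REDIS_HOST"],
    port=os.environ["REDIS_PORT"],
    password=os.environ["REDIS_PASSWORD"],
    supported_call_types=["completion", "acompletion"],
)
```

With this set, repeated identical `completion`/`acompletion` calls can be served from Redis, as in the logging example above.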
@@ -4,7 +4,9 @@
You can create a custom callback class to precisely log events as they occur in litellm.

```python
+import litellm
from litellm.integrations.custom_logger import CustomLogger
+from litellm import completion, acompletion

class MyCustomHandler(CustomLogger):
    def log_pre_api_call(self, model, messages, kwargs):
@@ -22,13 +24,37 @@ class MyCustomHandler(CustomLogger):
    def log_failure_event(self, kwargs, response_obj, start_time, end_time):
        print(f"On Failure")

+    #### ASYNC #### - for acompletion/aembeddings
+
+    async def async_log_stream_event(self, kwargs, response_obj, start_time, end_time):
+        print(f"On Async Streaming")
+
+    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
+        print(f"On Async Success")
+
+    async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
+        print(f"On Async Failure")
+
customHandler = MyCustomHandler()

litellm.callbacks = [customHandler]

+## sync
response = completion(model="gpt-3.5-turbo", messages=[{ "role": "user", "content": "Hi 👋 - i'm openai"}],
                      stream=True)
for chunk in response:
    continue

+
+## async
+import asyncio
+
+async def completion():
+    response = await acompletion(model="gpt-3.5-turbo", messages=[{ "role": "user", "content": "Hi 👋 - i'm openai"}],
+                                 stream=True)
+    async for chunk in response:
+        continue
+
+asyncio.run(completion())
```

## Callback Functions
@@ -87,6 +113,41 @@ print(response)

## Async Callback Functions

+We recommend using the Custom Logger class for async.
+
+```python
+import litellm
+from litellm.integrations.custom_logger import CustomLogger
+from litellm import acompletion
+
+class MyCustomHandler(CustomLogger):
+    #### ASYNC ####
+
+    async def async_log_stream_event(self, kwargs, response_obj, start_time, end_time):
+        print(f"On Async Streaming")
+
+    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
+        print(f"On Async Success")
+
+    async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
+        print(f"On Async Failure")
+
+import asyncio
+customHandler = MyCustomHandler()
+
+litellm.callbacks = [customHandler]
+
+async def completion():
+    response = await acompletion(model="gpt-3.5-turbo", messages=[{ "role": "user", "content": "Hi 👋 - i'm openai"}],
+                                 stream=True)
+    async for chunk in response:
+        continue
+
+asyncio.run(completion())
+```
+
+**Functions**
+
+If you just want to pass in an async function for logging.
+
LiteLLM currently supports just async success callback functions for async completion/embedding calls.

```python
@@ -117,9 +178,6 @@ asyncio.run(test_chat_openai())
:::info

We're actively trying to expand this to other event types. [Tell us if you need this!](https://github.com/BerriAI/litellm/issues/1007)
-
-
-
:::

## What's in kwargs?
@@ -170,6 +228,48 @@ Here's exactly what you can expect in the kwargs dictionary:
"end_time" = end_time # datetime object of when call was completed
```

+### Cache hits
+
+Cache hits are logged in success events as `kwarg["cache_hit"]`.
+
+Here's an example of accessing it:
+
+```python
+import asyncio
+import os
+import time
+
+import litellm
+from litellm.integrations.custom_logger import CustomLogger
+from litellm import completion, acompletion, Cache
+
+class MyCustomHandler(CustomLogger):
+    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
+        print(f"On Success")
+        print(f"Value of Cache hit: {kwargs['cache_hit']}")
+
+async def test_async_completion_azure_caching():
+    customHandler_caching = MyCustomHandler()
+    litellm.cache = Cache(type="redis", host=os.environ['REDIS_HOST'], port=os.environ['REDIS_PORT'], password=os.environ['REDIS_PASSWORD'])
+    litellm.callbacks = [customHandler_caching]
+    unique_time = time.time()
+    response1 = await litellm.acompletion(model="azure/chatgpt-v-2",
+                                          messages=[{
+                                              "role": "user",
+                                              "content": f"Hi 👋 - i'm async azure {unique_time}"
+                                          }],
+                                          caching=True)
+    await asyncio.sleep(1)
+    print(f"customHandler_caching.states pre-cache hit: {customHandler_caching.states}")
+    response2 = await litellm.acompletion(model="azure/chatgpt-v-2",
+                                          messages=[{
+                                              "role": "user",
+                                              "content": f"Hi 👋 - i'm async azure {unique_time}"
+                                          }],
+                                          caching=True)
+    await asyncio.sleep(1)  # success callbacks are done in parallel
+    print(f"customHandler_caching.states post-cache hit: {customHandler_caching.states}")
+    assert len(customHandler_caching.errors) == 0
+    assert len(customHandler_caching.states) == 4  # pre, post, success, success
+```

### Get complete streaming response

LiteLLM will pass you the complete streaming response in the final streaming chunk as part of the kwargs for your custom callback function.
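A sketch of reading that final chunk from a success callback; the `complete_streaming_response` key follows the naming used in litellm's callback docs and should be treated as an assumption here:

```python
import litellm
from litellm.integrations.custom_logger import CustomLogger

class StreamingTotalsHandler(CustomLogger):
    def log_success_event(self, kwargs, response_obj, start_time, end_time):
        # assumption: the assembled response is attached to kwargs on the final streaming chunk
        full_response = kwargs.get("complete_streaming_response")
        if full_response is not None:
            print(full_response.choices[0].message.content)

litellm.callbacks = [StreamingTotalsHandler()]
```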
@@ -27,8 +27,8 @@ To get better visualizations on how your code behaves, you may want to annotate

## Exporting traces to other systems (e.g. Datadog, New Relic, and others)

-Since Traceloop SDK uses OpenTelemetry to send data, you can easily export your traces to other systems, such as Datadog, New Relic, and others. See [Traceloop docs on exporters](https://traceloop.com/docs/python-sdk/exporters) for more information.
+Since OpenLLMetry uses OpenTelemetry to send data, you can easily export your traces to other systems, such as Datadog, New Relic, and others. See [OpenLLMetry docs on exporters](https://www.traceloop.com/docs/openllmetry/integrations/introduction) for more information.

## Support

-For any question or issue with integration you can reach out to the Traceloop team on [Slack](https://join.slack.com/t/traceloopcommunity/shared_invite/zt-1plpfpm6r-zOHKI028VkpcWdobX65C~g) or via [email](mailto:dev@traceloop.com).
+For any question or issue with integration you can reach out to the Traceloop team on [Slack](https://traceloop.com/slack) or via [email](mailto:dev@traceloop.com).
docs/my-website/docs/projects/Docq.AI.md (new file, 21 lines)
@@ -0,0 +1,21 @@
**A private and secure ChatGPT alternative that knows your business.**

Upload docs, ask questions --> get answers.

Leverage GenAI with your confidential documents to increase efficiency and collaboration.

OSS core, everything can run in your environment. An extensible platform you can build your GenAI strategy on. Supports a variety of popular LLMs, including embedded models for air-gapped use cases.

[![Static Badge][docs-shield]][docs-url]
[![Static Badge][github-shield]][github-url]
[![X (formerly Twitter) Follow][twitter-shield]][twitter-url]

<!-- MARKDOWN LINKS & IMAGES -->
<!-- https://www.markdownguide.org/basic-syntax/#reference-style-links -->

[docs-shield]: https://img.shields.io/badge/docs-site-black?logo=materialformkdocs
[docs-url]: https://docqai.github.io/docq/
[github-shield]: https://img.shields.io/badge/Github-repo-black?logo=github
[github-url]: https://github.com/docqai/docq/
[twitter-shield]: https://img.shields.io/twitter/follow/docqai?logo=x&style=flat
[twitter-url]: https://twitter.com/docqai
docs/my-website/docs/providers/mistral.md (new file, 56 lines)
@@ -0,0 +1,56 @@
# Mistral AI API
https://docs.mistral.ai/api/

## API Key
```python
# env variable
os.environ['MISTRAL_API_KEY']
```

## Sample Usage
```python
from litellm import completion
import os

os.environ['MISTRAL_API_KEY'] = ""
response = completion(
    model="mistral/mistral-tiny",
    messages=[
        {"role": "user", "content": "hello from litellm"}
    ],
)
print(response)
```

## Sample Usage - Streaming
```python
from litellm import completion
import os

os.environ['MISTRAL_API_KEY'] = ""
response = completion(
    model="mistral/mistral-tiny",
    messages=[
        {"role": "user", "content": "hello from litellm"}
    ],
    stream=True
)

for chunk in response:
    print(chunk)
```

## Supported Models
All models listed here https://docs.mistral.ai/platform/endpoints are supported. We actively maintain the list of models, pricing, token window, etc. [here](https://github.com/BerriAI/litellm/blob/c1b25538277206b9f00de5254d80d6a83bb19a29/model_prices_and_context_window.json).

| Model Name     | Function Call                                          |
|----------------|--------------------------------------------------------|
| mistral-tiny   | `completion(model="mistral/mistral-tiny", messages)`   |
| mistral-small  | `completion(model="mistral/mistral-small", messages)`  |
| mistral-medium | `completion(model="mistral/mistral-medium", messages)` |
docs/my-website/docs/providers/openai_compatible.md (new file, 47 lines)
@@ -0,0 +1,47 @@
# OpenAI-Compatible Endpoints

To call models hosted behind an openai proxy, make 2 changes:

1. Put `openai/` in front of your model name, so litellm knows you're trying to call an openai-compatible endpoint.

2. **Do NOT** add anything additional to the base url e.g. `/v1/embedding`. LiteLLM uses the openai-client to make these calls, and that automatically adds the relevant endpoints.

## Usage

```python
import os

import litellm
from litellm import embedding

litellm.set_verbose = True

litellm_proxy_endpoint = "http://0.0.0.0:8000"
bearer_token = "sk-1234"

CHOSEN_LITE_LLM_EMBEDDING_MODEL = "openai/GPT-J 6B - Sagemaker Text Embedding (Internal)"

litellm.set_verbose = False

print(litellm_proxy_endpoint)

response = embedding(
    model=CHOSEN_LITE_LLM_EMBEDDING_MODEL,  # add `openai/` prefix to model so litellm knows to route to OpenAI
    api_key=bearer_token,
    api_base=litellm_proxy_endpoint,  # set API Base of your Custom OpenAI Endpoint
    input=["good morning from litellm"],
    api_version='2023-07-01-preview'
)

print('================================================')
print(len(response.data[0]['embedding']))
```
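The same two rules apply to chat completions behind an OpenAI-compatible endpoint; a minimal sketch (endpoint, key, and model name are placeholders):

```python
from litellm import completion

# the `openai/` prefix tells litellm to use its OpenAI-compatible client;
# keep the base url bare, with no /v1/chat/completions suffix
response = completion(
    model="openai/my-hosted-model",        # placeholder model name
    api_base="http://0.0.0.0:8000",        # placeholder OpenAI-compatible endpoint
    api_key="sk-1234",                     # placeholder key
    messages=[{"role": "user", "content": "good morning from litellm"}],
)
print(response.choices[0].message.content)
```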
@@ -1,4 +1,4 @@
-# VertexAI - Google
+# VertexAI - Google [Gemini]

<a target="_blank" href="https://colab.research.google.com/github/BerriAI/litellm/blob/main/cookbook/liteLLM_VertextAI_Example.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
@@ -10,6 +10,16 @@
* run `gcloud auth application-default login` See [Google Cloud Docs](https://cloud.google.com/docs/authentication/external/set-up-adc)
* Alternatively you can set `application_default_credentials.json`

+## Sample Usage
+```python
+import litellm
+from litellm import completion
+
+litellm.vertex_project = "hardy-device-38811" # Your Project ID
+litellm.vertex_location = "us-central1" # proj location
+
+response = completion(model="gemini-pro", messages=[{"role": "user", "content": "write code for saying hi from LiteLLM"}])
+```
+
## Set Vertex Project & Vertex Location
All calls using Vertex AI require the following parameters:
* Your Project ID
@@ -37,13 +47,50 @@ os.environ["VERTEXAI_LOCATION"] = "us-central1 # Your Location
litellm.vertex_location = "us-central1 # Your Location
```

-## Sample Usage
+## Gemini Pro
+| Model Name | Function Call                        |
+|------------|--------------------------------------|
+| gemini-pro | `completion('gemini-pro', messages)` |
+
+## Gemini Pro Vision
+| Model Name        | Function Call                               |
+|-------------------|---------------------------------------------|
+| gemini-pro-vision | `completion('gemini-pro-vision', messages)` |
+
+#### Using Gemini Pro Vision
+
+Call `gemini-pro-vision` in the same input/output format as OpenAI [`gpt-4-vision`](https://docs.litellm.ai/docs/providers/openai#openai-vision-models)
+
+LiteLLM Supports the following image types passed in `url`
+- Images with Cloud Storage URIs - gs://cloud-samples-data/generative-ai/image/boats.jpeg
+- Images with direct links - https://storage.googleapis.com/github-repo/img/gemini/intro/landmark3.jpg
+- Videos with Cloud Storage URIs - https://storage.googleapis.com/github-repo/img/gemini/multimodality_usecases_overview/pixel8.mp4
+
+**Example Request**
```python
import litellm
-litellm.vertex_project = "hardy-device-38811" # Your Project ID
-litellm.vertex_location = "us-central1" # proj location

-response = completion(model="chat-bison", messages=[{"role": "user", "content": "write code for saying hi from LiteLLM"}])
+response = litellm.completion(
+    model = "vertex_ai/gemini-pro-vision",
+    messages=[
+        {
+            "role": "user",
+            "content": [
+                {
+                    "type": "text",
+                    "text": "Whats in this image?"
+                },
+                {
+                    "type": "image_url",
+                    "image_url": {
+                        "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
+                    }
+                }
+            ]
+        }
+    ],
+)
+print(response)
```

## Chat Models
@@ -1,20 +1,24 @@
# Caching
Cache LLM Responses

+## Quick Start
Caching can be enabled by adding the `cache` key in the `config.yaml`
-#### Step 1: Add `cache` to the config.yaml
+### Step 1: Add `cache` to the config.yaml
```yaml
model_list:
  - model_name: gpt-3.5-turbo
    litellm_params:
      model: gpt-3.5-turbo
+  - model_name: text-embedding-ada-002
+    litellm_params:
+      model: text-embedding-ada-002

litellm_settings:
  set_verbose: True
  cache: True # set cache responses to True, litellm defaults to using a redis cache
```

-#### Step 2: Add Redis Credentials to .env
+### Step 2: Add Redis Credentials to .env
Set either `REDIS_URL` or the `REDIS_HOST` in your os environment, to enable caching.

```shell
@@ -32,12 +36,12 @@ REDIS_<redis-kwarg-name> = ""
```

[**See how it's read from the environment**](https://github.com/BerriAI/litellm/blob/4d7ff1b33b9991dcf38d821266290631d9bcd2dd/litellm/_redis.py#L40)
-#### Step 3: Run proxy with config
+### Step 3: Run proxy with config
```shell
$ litellm --config /path/to/config.yaml
```

-#### Using Caching
+## Using Caching - /chat/completions
Send the same request twice:
```shell
curl http://0.0.0.0:8000/v1/chat/completions \
@@ -57,9 +61,51 @@ curl http://0.0.0.0:8000/v1/chat/completions \
}'
```

-#### Control caching per completion request
+## Using Caching - /embeddings
+Send the same request twice:
+```shell
+curl --location 'http://0.0.0.0:8000/embeddings' \
+  --header 'Content-Type: application/json' \
+  --data ' {
+  "model": "text-embedding-ada-002",
+  "input": ["write a litellm poem"]
+  }'
+
+curl --location 'http://0.0.0.0:8000/embeddings' \
+  --header 'Content-Type: application/json' \
+  --data ' {
+  "model": "text-embedding-ada-002",
+  "input": ["write a litellm poem"]
+  }'
+```
+
+## Advanced
+### Set Cache Params on config.yaml
+```yaml
+model_list:
+  - model_name: gpt-3.5-turbo
+    litellm_params:
+      model: gpt-3.5-turbo
+  - model_name: text-embedding-ada-002
+    litellm_params:
+      model: text-embedding-ada-002
+
+litellm_settings:
+  set_verbose: True
+  cache: True # set cache responses to True, litellm defaults to using a redis cache
+  cache_params: # cache_params are optional
+    type: "redis" # The type of cache to initialize. Can be "local" or "redis". Defaults to "local".
+    host: "localhost" # The host address for the Redis cache. Required if type is "redis".
+    port: 6379 # The port number for the Redis cache. Required if type is "redis".
+    password: "your_password" # The password for the Redis cache. Required if type is "redis".
+
+    # Optional configurations
+    supported_call_types: ["acompletion", "completion", "embedding", "aembedding"] # defaults to all litellm call types
+```
+
+### Override caching per `chat/completions` request
Caching can be switched on/off per `/chat/completions` request
-- Caching **on** for completion - pass `caching=True`:
+- Caching **on** for individual completion - pass `caching=True`:
```shell
curl http://0.0.0.0:8000/v1/chat/completions \
  -H "Content-Type: application/json" \
@@ -70,7 +116,7 @@ Caching can be switched on/off per `/chat/completions` request
    "caching": true
  }'
```
-- Caching **off** for completion - pass `caching=False`:
+- Caching **off** for individual completion - pass `caching=False`:
```shell
curl http://0.0.0.0:8000/v1/chat/completions \
  -H "Content-Type: application/json" \
@@ -81,3 +127,28 @@ Caching can be switched on/off per `/chat/completions` request
    "caching": false
  }'
```

+### Override caching per `/embeddings` request
+
+Caching can be switched on/off per `/embeddings` request
+- Caching **on** for embedding - pass `caching=True`:
+```shell
+curl --location 'http://0.0.0.0:8000/embeddings' \
+  --header 'Content-Type: application/json' \
+  --data ' {
+  "model": "text-embedding-ada-002",
+  "input": ["write a litellm poem"],
+  "caching": true
+  }'
+```
+- Caching **off** for embedding - pass `caching=False`:
+```shell
+curl --location 'http://0.0.0.0:8000/embeddings' \
+  --header 'Content-Type: application/json' \
+  --data ' {
+  "model": "text-embedding-ada-002",
+  "input": ["write a litellm poem"],
+  "caching": false
+  }'
+```
docs/my-website/docs/proxy/call_hooks.md (new file, 78 lines)
@@ -0,0 +1,78 @@
# Modify Incoming Data

Modify data just before making litellm completion calls on the proxy

See a complete example with our [parallel request rate limiter](https://github.com/BerriAI/litellm/blob/main/litellm/proxy/hooks/parallel_request_limiter.py)

## Quick Start

1. In your Custom Handler add a new `async_pre_call_hook` function

This function is called just before a litellm completion call is made, and allows you to modify the data going into the litellm call [**See Code**](https://github.com/BerriAI/litellm/blob/589a6ca863000ba8e92c897ba0f776796e7a5904/litellm/proxy/proxy_server.py#L1000)

```python
from litellm.integrations.custom_logger import CustomLogger
import litellm

# This file includes the custom callbacks for LiteLLM Proxy
# Once defined, these can be passed in proxy_config.yaml
class MyCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/observability/custom_callback#callback-class
    # Class variables or attributes
    def __init__(self):
        pass

    #### ASYNC ####

    async def async_log_stream_event(self, kwargs, response_obj, start_time, end_time):
        pass

    async def async_log_pre_api_call(self, model, messages, kwargs):
        pass

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        pass

    async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
        pass

    #### CALL HOOKS - proxy only ####

    async def async_pre_call_hook(self, user_api_key_dict: UserAPIKeyAuth, cache: DualCache, data: dict, call_type: Literal["completion", "embeddings"]):
        data["model"] = "my-new-model"
        return data

proxy_handler_instance = MyCustomHandler()
```

2. Add this file to your proxy config

```yaml
model_list:
  - model_name: gpt-3.5-turbo
    litellm_params:
      model: gpt-3.5-turbo

litellm_settings:
  callbacks: custom_callbacks.proxy_handler_instance # sets litellm.callbacks = [proxy_handler_instance]
```

3. Start the server + test the request

```shell
$ litellm /path/to/config.yaml
```
```shell
curl --location 'http://0.0.0.0:8000/chat/completions' \
  --data ' {
  "model": "gpt-3.5-turbo",
  "messages": [
      {
        "role": "user",
        "content": "good morning good sir"
      }
  ],
  "user": "ishaan-app",
  "temperature": 0.2
  }'
```
@@ -1,4 +1,11 @@
-# Deploying LiteLLM Proxy
+import Tabs from '@theme/Tabs';
+import TabItem from '@theme/TabItem';
+
+# 🐳 Docker, Deploying LiteLLM Proxy
+
+## Dockerfile
+
+You can find the Dockerfile to build litellm proxy [here](https://github.com/BerriAI/litellm/blob/main/Dockerfile)

## Quick Start Docker Image: Github Container Registry
@@ -7,12 +14,12 @@ See the latest available ghcr docker image here:
https://github.com/berriai/litellm/pkgs/container/litellm

```shell
-docker pull ghcr.io/berriai/litellm:main-v1.10.1
+docker pull ghcr.io/berriai/litellm:main-v1.12.3
```

### Run the Docker Image
```shell
-docker run ghcr.io/berriai/litellm:main-v1.10.0
+docker run ghcr.io/berriai/litellm:main-v1.12.3
```

#### Run the Docker Image with LiteLLM CLI args
@@ -21,12 +28,12 @@ See all supported CLI args [here](https://docs.litellm.ai/docs/proxy/cli):

Here's how you can run the docker image and pass your config to `litellm`
```shell
-docker run ghcr.io/berriai/litellm:main-v1.10.0 --config your_config.yaml
+docker run ghcr.io/berriai/litellm:main-v1.12.3 --config your_config.yaml
```

Here's how you can run the docker image and start litellm on port 8002 with `num_workers=8`
```shell
-docker run ghcr.io/berriai/litellm:main-v1.10.0 --port 8002 --num_workers 8
+docker run ghcr.io/berriai/litellm:main-v1.12.3 --port 8002 --num_workers 8
```

#### Run the Docker Image using docker compose
@@ -42,6 +49,10 @@ Here's an example `docker-compose.yml` file
version: "3.9"
services:
  litellm:
+    build:
+      context: .
+      args:
+        target: runtime
    image: ghcr.io/berriai/litellm:main
    ports:
      - "8000:8000" # Map the container port to the host, change the host port if necessary
@@ -74,6 +85,26 @@ Your LiteLLM container should be running now on the defined port e.g. `8000`.
<iframe width="840" height="500" src="https://www.loom.com/embed/805964b3c8384b41be180a61442389a3" frameborder="0" webkitallowfullscreen mozallowfullscreen allowfullscreen></iframe>

+## Deploy on Google Cloud Run
+**Click the button** to deploy to Google Cloud Run
+
+[](https://deploy.cloud.run/?git_repo=https://github.com/BerriAI/litellm)
+
+#### Testing your deployed proxy
+**Assuming the required keys are set as Environment Variables**
+
+https://litellm-7yjrj3ha2q-uc.a.run.app is our example proxy, substitute it with your deployed cloud run app
+
+```shell
+curl https://litellm-7yjrj3ha2q-uc.a.run.app/v1/chat/completions \
+  -H "Content-Type: application/json" \
+  -d '{
+    "model": "gpt-3.5-turbo",
+    "messages": [{"role": "user", "content": "Say this is a test!"}],
+    "temperature": 0.7
+  }'
+```
+
## LiteLLM Proxy Performance

LiteLLM proxy has been load tested to handle 1500 req/s.
docs/my-website/docs/proxy/embedding.md (new file, 244 lines)
@@ -0,0 +1,244 @@
import Image from '@theme/IdealImage';
import Tabs from '@theme/Tabs';
import TabItem from '@theme/TabItem';

# Embeddings - `/embeddings`

See supported Embedding Providers & Models [here](https://docs.litellm.ai/docs/embedding/supported_embedding)

## Quick start
Here's how to route between GPT-J embedding (sagemaker endpoint), Amazon Titan embedding (Bedrock) and Azure OpenAI embedding on the proxy server:

1. Set models in your config.yaml
```yaml
model_list:
  - model_name: sagemaker-embeddings
    litellm_params:
      model: "sagemaker/berri-benchmarking-gpt-j-6b-fp16"
  - model_name: amazon-embeddings
    litellm_params:
      model: "bedrock/amazon.titan-embed-text-v1"
  - model_name: azure-embeddings
    litellm_params:
      model: "azure/azure-embedding-model"
      api_base: "os.environ/AZURE_API_BASE" # os.getenv("AZURE_API_BASE")
      api_key: "os.environ/AZURE_API_KEY" # os.getenv("AZURE_API_KEY")
      api_version: "2023-07-01-preview"

general_settings:
  master_key: sk-1234 # [OPTIONAL] if set all calls to proxy will require either this key or a valid generated token
```

2. Start the proxy
```shell
$ litellm --config /path/to/config.yaml
```

3. Test the embedding call

```shell
curl --location 'http://0.0.0.0:8000/v1/embeddings' \
  --header 'Authorization: Bearer sk-1234' \
  --header 'Content-Type: application/json' \
  --data '{
    "input": "The food was delicious and the waiter..",
    "model": "sagemaker-embeddings"
  }'
```

## `/embeddings` Request Format
Input, Output and Exceptions are mapped to the OpenAI format for all supported models

<Tabs>
<TabItem value="Curl" label="Curl Request">

```shell
curl --location 'http://0.0.0.0:8000/embeddings' \
  --header 'Content-Type: application/json' \
  --data ' {
  "model": "text-embedding-ada-002",
  "input": ["write a litellm poem"]
  }'
```
</TabItem>
<TabItem value="openai" label="OpenAI v1.0.0+">

```python
from openai import OpenAI

# set base_url to your proxy server
# set api_key to send to proxy server
client = OpenAI(api_key="<proxy-api-key>", base_url="http://0.0.0.0:8000")

response = client.embeddings.create(
    input=["hello from litellm"],
    model="text-embedding-ada-002"
)

print(response)
```
</TabItem>

<TabItem value="langchain-embedding" label="Langchain Embeddings">

```python
from langchain.embeddings import OpenAIEmbeddings

embeddings = OpenAIEmbeddings(model="sagemaker-embeddings", openai_api_base="http://0.0.0.0:8000", openai_api_key="temp-key")

text = "This is a test document."

query_result = embeddings.embed_query(text)

print(f"SAGEMAKER EMBEDDINGS")
print(query_result[:5])

embeddings = OpenAIEmbeddings(model="bedrock-embeddings", openai_api_base="http://0.0.0.0:8000", openai_api_key="temp-key")

text = "This is a test document."

query_result = embeddings.embed_query(text)

print(f"BEDROCK EMBEDDINGS")
print(query_result[:5])

embeddings = OpenAIEmbeddings(model="bedrock-titan-embeddings", openai_api_base="http://0.0.0.0:8000", openai_api_key="temp-key")

text = "This is a test document."

query_result = embeddings.embed_query(text)

print(f"TITAN EMBEDDINGS")
print(query_result[:5])
```
</TabItem>
</Tabs>

## `/embeddings` Response Format

```json
{
  "object": "list",
  "data": [
    {
      "object": "embedding",
      "embedding": [
        0.0023064255,
        -0.009327292,
        ....
        -0.0028842222,
      ],
      "index": 0
    }
  ],
  "model": "text-embedding-ada-002",
  "usage": {
    "prompt_tokens": 8,
    "total_tokens": 8
  }
}
```

## Supported Models

See supported Embedding Providers & Models [here](https://docs.litellm.ai/docs/embedding/supported_embedding)

#### Create Config.yaml

<Tabs>
<TabItem value="Hugging Face emb" label="Hugging Face Embeddings">
LiteLLM Proxy supports all <a href="https://huggingface.co/models?pipeline_tag=feature-extraction">Feature-Extraction Embedding models</a>.

```yaml
model_list:
  - model_name: deployed-codebert-base
    litellm_params:
      # send request to deployed hugging face inference endpoint
      model: huggingface/microsoft/codebert-base # add huggingface prefix so it routes to hugging face
      api_key: hf_LdS # api key for hugging face inference endpoint
      api_base: https://uysneno1wv2wd4lw.us-east-1.aws.endpoints.huggingface.cloud # your hf inference endpoint
  - model_name: codebert-base
    litellm_params:
      # no api_base set, sends request to hugging face free inference api https://api-inference.huggingface.co/models/
      model: huggingface/microsoft/codebert-base # add huggingface prefix so it routes to hugging face
      api_key: hf_LdS # api key for hugging face
```

</TabItem>

<TabItem value="azure" label="Azure OpenAI Embeddings">

```yaml
model_list:
  - model_name: azure-embedding-model # model group
    litellm_params:
      model: azure/azure-embedding-model # model name for litellm.embedding(model=azure/azure-embedding-model) call
      api_base: your-azure-api-base
      api_key: your-api-key
      api_version: 2023-07-01-preview
```

</TabItem>

<TabItem value="openai" label="OpenAI Embeddings">

```yaml
model_list:
  - model_name: text-embedding-ada-002 # model group
    litellm_params:
      model: text-embedding-ada-002 # model name for litellm.embedding(model=text-embedding-ada-002)
      api_key: your-api-key-1
  - model_name: text-embedding-ada-002
    litellm_params:
      model: text-embedding-ada-002
      api_key: your-api-key-2
```

</TabItem>

<TabItem value="openai emb" label="OpenAI Compatible Embeddings">

<p>Use this for calling <a href="https://github.com/xorbitsai/inference">/embedding endpoints on OpenAI Compatible Servers</a>.</p>

**Note add `openai/` prefix to `litellm_params`: `model` so litellm knows to route to OpenAI**

```yaml
model_list:
  - model_name: text-embedding-ada-002 # model group
    litellm_params:
      model: openai/<your-model-name> # model name for litellm.embedding(model=text-embedding-ada-002)
      api_base: <model-api-base>
```

</TabItem>
</Tabs>

#### Start Proxy
```shell
litellm --config config.yaml
```

#### Make Request
Sends Request to `deployed-codebert-base`

```shell
curl --location 'http://0.0.0.0:8000/embeddings' \
  --header 'Content-Type: application/json' \
  --data ' {
  "model": "deployed-codebert-base",
  "input": ["write a litellm poem"]
  }'
```
docs/my-website/docs/proxy/health.md (new file, 62 lines)
@@ -0,0 +1,62 @@
# Health Checks
Use this to health check all LLMs defined in your config.yaml

## Summary

The proxy exposes:
* a /health endpoint which returns the health of the LLM APIs
* a /test endpoint which makes a ping to the litellm server

#### Request
Make a GET Request to `/health` on the proxy
```shell
curl --location 'http://0.0.0.0:8000/health'
```

You can also run `litellm --health`, which makes a `GET` request to `http://0.0.0.0:8000/health` for you
```
litellm --health
```
#### Response
```shell
{
    "healthy_endpoints": [
        {
            "model": "azure/gpt-35-turbo",
            "api_base": "https://my-endpoint-canada-berri992.openai.azure.com/"
        },
        {
            "model": "azure/gpt-35-turbo",
            "api_base": "https://my-endpoint-europe-berri-992.openai.azure.com/"
        }
    ],
    "unhealthy_endpoints": [
        {
            "model": "azure/gpt-35-turbo",
            "api_base": "https://openai-france-1234.openai.azure.com/"
        }
    ]
}
```

## Background Health Checks

You can enable model health checks being run in the background, to prevent each model from being queried too frequently via `/health`.

Here's how to use it:
1. In the config.yaml add:
```
general_settings:
  background_health_checks: True # enable background health checks
  health_check_interval: 300 # frequency of background health checks
```

2. Start server
```
$ litellm /path/to/config.yaml
```

3. Query health endpoint:
```
curl --location 'http://0.0.0.0:8000/health'
```
@@ -72,128 +72,28 @@ curl --location 'http://0.0.0.0:8000/chat/completions' \
'
```

-## Fallbacks + Cooldowns + Retries + Timeouts
-
-If a call fails after num_retries, fall back to another model group.
-
-If the error is a context window exceeded error, fall back to a larger model group (if given).
-
-[**See Code**](https://github.com/BerriAI/litellm/blob/main/litellm/router.py)
-
-**Set via config**
-```yaml
-model_list:
-  - model_name: zephyr-beta
-    litellm_params:
-      model: huggingface/HuggingFaceH4/zephyr-7b-beta
-      api_base: http://0.0.0.0:8001
-  - model_name: zephyr-beta
-    litellm_params:
-      model: huggingface/HuggingFaceH4/zephyr-7b-beta
-      api_base: http://0.0.0.0:8002
-  - model_name: zephyr-beta
-    litellm_params:
-      model: huggingface/HuggingFaceH4/zephyr-7b-beta
-      api_base: http://0.0.0.0:8003
-  - model_name: gpt-3.5-turbo
-    litellm_params:
-      model: gpt-3.5-turbo
-      api_key: <my-openai-key>
-  - model_name: gpt-3.5-turbo-16k
-    litellm_params:
-      model: gpt-3.5-turbo-16k
-      api_key: <my-openai-key>
-
-litellm_settings:
-  num_retries: 3 # retry call 3 times on each model_name (e.g. zephyr-beta)
-  request_timeout: 10 # raise Timeout error if call takes longer than 10s. Sets litellm.request_timeout
-  fallbacks: [{"zephyr-beta": ["gpt-3.5-turbo"]}] # fallback to gpt-3.5-turbo if call fails num_retries
-  context_window_fallbacks: [{"zephyr-beta": ["gpt-3.5-turbo-16k"]}, {"gpt-3.5-turbo": ["gpt-3.5-turbo-16k"]}] # fallback to gpt-3.5-turbo-16k if context window error
-  allowed_fails: 3 # cooldown model if it fails > 1 call in a minute.
-```
-
-**Set dynamically**
-
-```bash
-curl --location 'http://0.0.0.0:8000/chat/completions' \
---header 'Content-Type: application/json' \
---data ' {
-  "model": "zephyr-beta",
-  "messages": [
-      {
-        "role": "user",
-        "content": "what llm are you"
-      }
-  ],
-  "fallbacks": [{"zephyr-beta": ["gpt-3.5-turbo"]}],
-  "context_window_fallbacks": [{"zephyr-beta": ["gpt-3.5-turbo"]}],
-  "num_retries": 2,
-  "timeout": 10
-}
-'
-```
-
-## Custom Timeouts, Stream Timeouts - Per Model
-For each model you can set `timeout` & `stream_timeout` under `litellm_params`
+## Router settings on config - routing_strategy, model_group_alias
+
+litellm.Router() settings can be set under `router_settings`. You can set `model_group_alias`, `routing_strategy`, `num_retries`, `timeout`. See all Router supported params [here](https://github.com/BerriAI/litellm/blob/1b942568897a48f014fa44618ec3ce54d7570a46/litellm/router.py#L64)
+
+Example config with `router_settings`
```yaml
model_list:
  - model_name: gpt-3.5-turbo
    litellm_params:
-      model: azure/gpt-turbo-small-eu
-      api_base: https://my-endpoint-europe-berri-992.openai.azure.com/
-      api_key: <your-key>
-      timeout: 0.1 # timeout in (seconds)
-      stream_timeout: 0.01 # timeout for stream requests (seconds)
-      max_retries: 5
+      model: azure/<your-deployment-name>
+      api_base: <your-azure-endpoint>
+      api_key: <your-azure-api-key>
+      rpm: 6 # Rate limit for this deployment: in requests per minute (rpm)
  - model_name: gpt-3.5-turbo
    litellm_params:
      model: azure/gpt-turbo-small-ca
      api_base: https://my-endpoint-canada-berri992.openai.azure.com/
-      api_key:
-      timeout: 0.1 # timeout in (seconds)
-      stream_timeout: 0.01 # timeout for stream requests (seconds)
-      max_retries: 5
+      api_key: <your-azure-api-key>
+      rpm: 6
+router_settings:
+  model_group_alias: {"gpt-4": "gpt-3.5-turbo"} # all requests with `gpt-4` will be routed to models with `gpt-3.5-turbo`
|
routing_strategy: least-busy # Literal["simple-shuffle", "least-busy", "usage-based-routing", "latency-based-routing"]
|
||||||
```
|
num_retries: 2
|
||||||
|
timeout: 30 # 30 seconds
|
||||||
#### Start Proxy
|
|
||||||
```shell
|
|
||||||
$ litellm --config /path/to/config.yaml
|
|
||||||
```
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
## Health Check LLMs on Proxy
|
|
||||||
Use this to health check all LLMs defined in your config.yaml
|
|
||||||
#### Request
|
|
||||||
Make a GET Request to `/health` on the proxy
|
|
||||||
```shell
|
|
||||||
curl --location 'http://0.0.0.0:8000/health'
|
|
||||||
```
|
|
||||||
|
|
||||||
You can also run `litellm -health` it makes a `get` request to `http://0.0.0.0:8000/health` for you
|
|
||||||
```
|
|
||||||
litellm --health
|
|
||||||
```
|
|
||||||
#### Response
|
|
||||||
```shell
|
|
||||||
{
|
|
||||||
"healthy_endpoints": [
|
|
||||||
{
|
|
||||||
"model": "azure/gpt-35-turbo",
|
|
||||||
"api_base": "https://my-endpoint-canada-berri992.openai.azure.com/"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"model": "azure/gpt-35-turbo",
|
|
||||||
"api_base": "https://my-endpoint-europe-berri-992.openai.azure.com/"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"unhealthy_endpoints": [
|
|
||||||
{
|
|
||||||
"model": "azure/gpt-35-turbo",
|
|
||||||
"api_base": "https://openai-france-1234.openai.azure.com/"
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
```
|
```
|
|
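For the `router_settings` example added earlier in this hunk (`model_group_alias: {"gpt-4": "gpt-3.5-turbo"}`), a request that names `gpt-4` should be served by the `gpt-3.5-turbo` deployments. A request sketch against the proxy, assuming it runs on `0.0.0.0:8000` and no virtual keys are enforced:

```python
# Request "gpt-4"; with model_group_alias set, the proxy routes it to gpt-3.5-turbo deployments
import requests

resp = requests.post(
    "http://0.0.0.0:8000/chat/completions",
    json={
        "model": "gpt-4",   # alias resolved via model_group_alias
        "messages": [{"role": "user", "content": "what llm are you"}],
    },
)
print(resp.json())
```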
@ -1,5 +1,8 @@
|
||||||
# Logging - Custom Callbacks, OpenTelemetry, Langfuse
|
import Image from '@theme/IdealImage';
|
||||||
Log Proxy Input, Output, Exceptions using Custom Callbacks, Langfuse, OpenTelemetry
|
|
||||||
|
# Logging - Custom Callbacks, OpenTelemetry, Langfuse, Sentry
|
||||||
|
|
||||||
|
Log Proxy Input, Output, Exceptions using Custom Callbacks, Langfuse, OpenTelemetry, DynamoDB, Sentry
|
||||||
|
|
||||||
## Custom Callback Class [Async]
|
## Custom Callback Class [Async]
|
||||||
Use this when you want to run custom callbacks in `python`
|
Use this when you want to run custom callbacks in `python`
|
||||||
|
@ -486,3 +489,166 @@ litellm --test
|
||||||
Expected output on Langfuse
|
Expected output on Langfuse
|
||||||
|
|
||||||
<Image img={require('../../img/langfuse_small.png')} />
|
<Image img={require('../../img/langfuse_small.png')} />
|
||||||
|
|
||||||
|
## Logging Proxy Input/Output - DynamoDB
|
||||||
|
|
||||||
|
We will use the `--config` to set
|
||||||
|
- `litellm.success_callback = ["dynamodb"]`
|
||||||
|
- `litellm.dynamodb_table_name = "your-table-name"`
|
||||||
|
|
||||||
|
This will log all successful LLM calls to DynamoDB
|
||||||
|
|
||||||
|
**Step 1** Set AWS Credentials in .env
|
||||||
|
|
||||||
|
```shell
|
||||||
|
AWS_ACCESS_KEY_ID = ""
|
||||||
|
AWS_SECRET_ACCESS_KEY = ""
|
||||||
|
AWS_REGION_NAME = ""
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 2**: Create a `config.yaml` file and set `litellm_settings`: `success_callback`
|
||||||
|
```yaml
|
||||||
|
model_list:
|
||||||
|
- model_name: gpt-3.5-turbo
|
||||||
|
litellm_params:
|
||||||
|
model: gpt-3.5-turbo
|
||||||
|
litellm_settings:
|
||||||
|
success_callback: ["dynamodb"]
|
||||||
|
dynamodb_table_name: your-table-name
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 3**: Start the proxy, make a test request
|
||||||
|
|
||||||
|
Start proxy
|
||||||
|
```shell
|
||||||
|
litellm --config config.yaml --debug
|
||||||
|
```
|
||||||
|
|
||||||
|
Test Request
|
||||||
|
```shell
|
||||||
|
curl --location 'http://0.0.0.0:8000/chat/completions' \
|
||||||
|
--header 'Content-Type: application/json' \
|
||||||
|
--data ' {
|
||||||
|
"model": "Azure OpenAI GPT-4 East",
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "what llm are you"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}'
|
||||||
|
```
|
||||||
|
|
||||||
|
Your logs should be available on DynamoDB
|
||||||
|
|
||||||
|
#### Data Logged to DynamoDB /chat/completions
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"id": {
|
||||||
|
"S": "chatcmpl-8W15J4480a3fAQ1yQaMgtsKJAicen"
|
||||||
|
},
|
||||||
|
"call_type": {
|
||||||
|
"S": "acompletion"
|
||||||
|
},
|
||||||
|
"endTime": {
|
||||||
|
"S": "2023-12-15 17:25:58.424118"
|
||||||
|
},
|
||||||
|
"messages": {
|
||||||
|
"S": "[{'role': 'user', 'content': 'This is a test'}]"
|
||||||
|
},
|
||||||
|
"metadata": {
|
||||||
|
"S": "{}"
|
||||||
|
},
|
||||||
|
"model": {
|
||||||
|
"S": "gpt-3.5-turbo"
|
||||||
|
},
|
||||||
|
"modelParameters": {
|
||||||
|
"S": "{'temperature': 0.7, 'max_tokens': 100, 'user': 'ishaan-2'}"
|
||||||
|
},
|
||||||
|
"response": {
|
||||||
|
"S": "ModelResponse(id='chatcmpl-8W15J4480a3fAQ1yQaMgtsKJAicen', choices=[Choices(finish_reason='stop', index=0, message=Message(content='Great! What can I assist you with?', role='assistant'))], created=1702641357, model='gpt-3.5-turbo-0613', object='chat.completion', system_fingerprint=None, usage=Usage(completion_tokens=9, prompt_tokens=11, total_tokens=20))"
|
||||||
|
},
|
||||||
|
"startTime": {
|
||||||
|
"S": "2023-12-15 17:25:56.047035"
|
||||||
|
},
|
||||||
|
"usage": {
|
||||||
|
"S": "Usage(completion_tokens=9, prompt_tokens=11, total_tokens=20)"
|
||||||
|
},
|
||||||
|
"user": {
|
||||||
|
"S": "ishaan-2"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Data logged to DynamoDB /embeddings
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"id": {
|
||||||
|
"S": "4dec8d4d-4817-472d-9fc6-c7a6153eb2ca"
|
||||||
|
},
|
||||||
|
"call_type": {
|
||||||
|
"S": "aembedding"
|
||||||
|
},
|
||||||
|
"endTime": {
|
||||||
|
"S": "2023-12-15 17:25:59.890261"
|
||||||
|
},
|
||||||
|
"messages": {
|
||||||
|
"S": "['hi']"
|
||||||
|
},
|
||||||
|
"metadata": {
|
||||||
|
"S": "{}"
|
||||||
|
},
|
||||||
|
"model": {
|
||||||
|
"S": "text-embedding-ada-002"
|
||||||
|
},
|
||||||
|
"modelParameters": {
|
||||||
|
"S": "{'user': 'ishaan-2'}"
|
||||||
|
},
|
||||||
|
"response": {
|
||||||
|
"S": "EmbeddingResponse(model='text-embedding-ada-002-v2', data=[{'embedding': [-0.03503197431564331, -0.020601635798811913, -0.015375726856291294,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
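The integration expects the table named in `dynamodb_table_name` to already exist. Here is a hedged boto3 sketch for creating it: the single string partition key `id` matches the `id` attribute in the logged items above, and on-demand billing is just one reasonable choice, not a requirement.

```python
# Create the DynamoDB table the logger writes to (run once)
import boto3

dynamodb = boto3.client("dynamodb", region_name="us-west-2")  # use your AWS_REGION_NAME

dynamodb.create_table(
    TableName="your-table-name",  # must match dynamodb_table_name in config.yaml
    KeySchema=[{"AttributeName": "id", "KeyType": "HASH"}],
    AttributeDefinitions=[{"AttributeName": "id", "AttributeType": "S"}],
    BillingMode="PAY_PER_REQUEST",  # on-demand capacity
)
dynamodb.get_waiter("table_exists").wait(TableName="your-table-name")
```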
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
## Logging Proxy Input/Output - Sentry
|
||||||
|
|
||||||
|
If api calls fail (llm/database) you can log those to Sentry:
|
||||||
|
|
||||||
|
**Step 1** Install Sentry
|
||||||
|
```shell
|
||||||
|
pip install --upgrade sentry-sdk
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 2**: Save your Sentry_DSN and add `litellm_settings`: `failure_callback`
|
||||||
|
```shell
|
||||||
|
export SENTRY_DSN="your-sentry-dsn"
|
||||||
|
```
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
model_list:
|
||||||
|
- model_name: gpt-3.5-turbo
|
||||||
|
litellm_params:
|
||||||
|
model: gpt-3.5-turbo
|
||||||
|
litellm_settings:
|
||||||
|
# other settings
|
||||||
|
failure_callback: ["sentry"]
|
||||||
|
general_settings:
|
||||||
|
database_url: "my-bad-url" # set a fake url to trigger a sentry exception
|
||||||
|
```
|
||||||
|
|
||||||
|
**Step 3**: Start the proxy, make a test request
|
||||||
|
|
||||||
|
Start proxy
|
||||||
|
```shell
|
||||||
|
litellm --config config.yaml --debug
|
||||||
|
```
|
||||||
|
|
||||||
|
Test Request
|
||||||
|
```
|
||||||
|
litellm --test
|
||||||
|
```
|
||||||
|
|
docs/my-website/docs/proxy/reliability.md (new file, 89 lines)
@ -0,0 +1,89 @@
# Fallbacks, Retries, Timeouts, Cooldowns

If a call fails after num_retries, fall back to another model group.

If the error is a context window exceeded error, fall back to a larger model group (if given).

[**See Code**](https://github.com/BerriAI/litellm/blob/main/litellm/router.py)

**Set via config**
```yaml
model_list:
  - model_name: zephyr-beta
    litellm_params:
      model: huggingface/HuggingFaceH4/zephyr-7b-beta
      api_base: http://0.0.0.0:8001
  - model_name: zephyr-beta
    litellm_params:
      model: huggingface/HuggingFaceH4/zephyr-7b-beta
      api_base: http://0.0.0.0:8002
  - model_name: zephyr-beta
    litellm_params:
      model: huggingface/HuggingFaceH4/zephyr-7b-beta
      api_base: http://0.0.0.0:8003
  - model_name: gpt-3.5-turbo
    litellm_params:
      model: gpt-3.5-turbo
      api_key: <my-openai-key>
  - model_name: gpt-3.5-turbo-16k
    litellm_params:
      model: gpt-3.5-turbo-16k
      api_key: <my-openai-key>

litellm_settings:
  num_retries: 3 # retry call 3 times on each model_name (e.g. zephyr-beta)
  request_timeout: 10 # raise Timeout error if call takes longer than 10s. Sets litellm.request_timeout
  fallbacks: [{"zephyr-beta": ["gpt-3.5-turbo"]}] # fallback to gpt-3.5-turbo if call fails after num_retries
  context_window_fallbacks: [{"zephyr-beta": ["gpt-3.5-turbo-16k"]}, {"gpt-3.5-turbo": ["gpt-3.5-turbo-16k"]}] # fallback to gpt-3.5-turbo-16k on context window errors
  allowed_fails: 3 # cooldown a deployment if it fails more than 3 calls in a minute
```

**Set dynamically**

```bash
curl --location 'http://0.0.0.0:8000/chat/completions' \
--header 'Content-Type: application/json' \
--data ' {
  "model": "zephyr-beta",
  "messages": [
      {
        "role": "user",
        "content": "what llm are you"
      }
  ],
  "fallbacks": [{"zephyr-beta": ["gpt-3.5-turbo"]}],
  "context_window_fallbacks": [{"zephyr-beta": ["gpt-3.5-turbo"]}],
  "num_retries": 2,
  "timeout": 10
}
'
```

## Custom Timeouts, Stream Timeouts - Per Model
For each model you can set `timeout` & `stream_timeout` under `litellm_params`
```yaml
model_list:
  - model_name: gpt-3.5-turbo
    litellm_params:
      model: azure/gpt-turbo-small-eu
      api_base: https://my-endpoint-europe-berri-992.openai.azure.com/
      api_key: <your-key>
      timeout: 0.1 # timeout in seconds
      stream_timeout: 0.01 # timeout for stream requests (seconds)
      max_retries: 5
  - model_name: gpt-3.5-turbo
    litellm_params:
      model: azure/gpt-turbo-small-ca
      api_base: https://my-endpoint-canada-berri992.openai.azure.com/
      api_key:
      timeout: 0.1 # timeout in seconds
      stream_timeout: 0.01 # timeout for stream requests (seconds)
      max_retries: 5
```

#### Start Proxy
```shell
$ litellm --config /path/to/config.yaml
```
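These proxy settings map onto `litellm.Router`. As a minimal sketch of the same retry-then-fallback behaviour without the proxy (model list trimmed to two groups, keys are placeholders, and only the Router parameters named in these docs are used):

```python
# Router-level retries + fallbacks, mirroring the config above
from litellm import Router

model_list = [
    {
        "model_name": "zephyr-beta",
        "litellm_params": {
            "model": "huggingface/HuggingFaceH4/zephyr-7b-beta",
            "api_base": "http://0.0.0.0:8001",
        },
    },
    {
        "model_name": "gpt-3.5-turbo",
        "litellm_params": {"model": "gpt-3.5-turbo", "api_key": "<my-openai-key>"},
    },
]

router = Router(
    model_list=model_list,
    num_retries=3,                                   # retry each deployment up to 3 times
    timeout=10,                                      # per-request timeout in seconds
    fallbacks=[{"zephyr-beta": ["gpt-3.5-turbo"]}],  # fall back once retries are exhausted
)

response = router.completion(
    model="zephyr-beta",
    messages=[{"role": "user", "content": "what llm are you"}],
)
print(response)
```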
@ -1,13 +1,13 @@
 import Tabs from '@theme/Tabs';
 import TabItem from '@theme/TabItem';
 
-# [OLD PROXY 👉 [**NEW** proxy here](./simple_proxy.md)] Local OpenAI Proxy Server
+# [OLD PROXY 👉 [**NEW** proxy here](./simple_proxy)] Local OpenAI Proxy Server
 
 A fast, and lightweight OpenAI-compatible server to call 100+ LLM APIs.
 
 :::info
 
-Docs outdated. New docs 👉 [here](./simple_proxy.md)
+Docs outdated. New docs 👉 [here](./simple_proxy)
 
 :::
|
@ -366,6 +366,63 @@ router = Router(model_list: Optional[list] = None,
|
||||||
cache_responses=True)
|
cache_responses=True)
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Caching across model groups
|
||||||
|
|
||||||
|
If you want to cache across 2 different model groups (e.g. azure deployments, and openai), use caching groups.
|
||||||
|
|
||||||
|
```python
|
||||||
|
import litellm, asyncio, time
|
||||||
|
from litellm import Router
|
||||||
|
|
||||||
|
# set os env
|
||||||
|
os.environ["OPENAI_API_KEY"] = ""
|
||||||
|
os.environ["AZURE_API_KEY"] = ""
|
||||||
|
os.environ["AZURE_API_BASE"] = ""
|
||||||
|
os.environ["AZURE_API_VERSION"] = ""
|
||||||
|
|
||||||
|
async def test_acompletion_caching_on_router_caching_groups():
|
||||||
|
# tests acompletion + caching on router
|
||||||
|
try:
|
||||||
|
litellm.set_verbose = True
|
||||||
|
model_list = [
|
||||||
|
{
|
||||||
|
"model_name": "openai-gpt-3.5-turbo",
|
||||||
|
"litellm_params": {
|
||||||
|
"model": "gpt-3.5-turbo-0613",
|
||||||
|
"api_key": os.getenv("OPENAI_API_KEY"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"model_name": "azure-gpt-3.5-turbo",
|
||||||
|
"litellm_params": {
|
||||||
|
"model": "azure/chatgpt-v-2",
|
||||||
|
"api_key": os.getenv("AZURE_API_KEY"),
|
||||||
|
"api_base": os.getenv("AZURE_API_BASE"),
|
||||||
|
"api_version": os.getenv("AZURE_API_VERSION")
|
||||||
|
},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
messages = [
|
||||||
|
{"role": "user", "content": f"write a one sentence poem {time.time()}?"}
|
||||||
|
]
|
||||||
|
start_time = time.time()
|
||||||
|
router = Router(model_list=model_list,
|
||||||
|
cache_responses=True,
|
||||||
|
caching_groups=[("openai-gpt-3.5-turbo", "azure-gpt-3.5-turbo")])
|
||||||
|
response1 = await router.acompletion(model="openai-gpt-3.5-turbo", messages=messages, temperature=1)
|
||||||
|
print(f"response1: {response1}")
|
||||||
|
await asyncio.sleep(1) # add cache is async, async sleep for cache to get set
|
||||||
|
response2 = await router.acompletion(model="azure-gpt-3.5-turbo", messages=messages, temperature=1)
|
||||||
|
assert response1.id == response2.id
|
||||||
|
assert len(response1.choices[0].message.content) > 0
|
||||||
|
assert response1.choices[0].message.content == response2.choices[0].message.content
|
||||||
|
except Exception as e:
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
|
asyncio.run(test_acompletion_caching_on_router_caching_groups())
|
||||||
|
```
|
||||||
|
|
||||||
#### Default litellm.completion/embedding params
|
#### Default litellm.completion/embedding params
|
||||||
|
|
||||||
You can also set default params for litellm completion/embedding calls. Here's how to do that:
|
You can also set default params for litellm completion/embedding calls. Here's how to do that:
|
||||||
|
@ -391,200 +448,3 @@ print(f"response: {response}")
|
||||||
## Deploy Router
|
## Deploy Router
|
||||||
|
|
||||||
If you want a server to load balance across different LLM APIs, use our [OpenAI Proxy Server](./simple_proxy#load-balancing---multiple-instances-of-1-model)
|
If you want a server to load balance across different LLM APIs, use our [OpenAI Proxy Server](./simple_proxy#load-balancing---multiple-instances-of-1-model)
|
||||||
|
|
||||||
## Queuing (Beta)
|
|
||||||
|
|
||||||
**Never fail a request due to rate limits**
|
|
||||||
|
|
||||||
The LiteLLM Queuing endpoints can handle 100+ req/s. We use Celery workers to process requests.
|
|
||||||
|
|
||||||
:::info
|
|
||||||
|
|
||||||
This is pretty new, and might have bugs. Any contributions to improving our implementation are welcome
|
|
||||||
|
|
||||||
:::
|
|
||||||
|
|
||||||
|
|
||||||
[**See Code**](https://github.com/BerriAI/litellm/blob/fbf9cab5b9e35df524e2c9953180c58d92e4cd97/litellm/proxy/proxy_server.py#L589)
|
|
||||||
|
|
||||||
|
|
||||||
### Quick Start
|
|
||||||
|
|
||||||
1. Add Redis credentials in a .env file
|
|
||||||
|
|
||||||
```python
|
|
||||||
REDIS_HOST="my-redis-endpoint"
|
|
||||||
REDIS_PORT="my-redis-port"
|
|
||||||
REDIS_PASSWORD="my-redis-password" # [OPTIONAL] if self-hosted
|
|
||||||
REDIS_USERNAME="default" # [OPTIONAL] if self-hosted
|
|
||||||
```
|
|
||||||
|
|
||||||
2. Start litellm server with your model config
|
|
||||||
|
|
||||||
```bash
|
|
||||||
$ litellm --config /path/to/config.yaml --use_queue
|
|
||||||
```
|
|
||||||
|
|
||||||
Here's an example config for `gpt-3.5-turbo`
|
|
||||||
|
|
||||||
**config.yaml** (This will load balance between OpenAI + Azure endpoints)
|
|
||||||
```yaml
|
|
||||||
model_list:
|
|
||||||
- model_name: gpt-3.5-turbo
|
|
||||||
litellm_params:
|
|
||||||
model: gpt-3.5-turbo
|
|
||||||
api_key:
|
|
||||||
- model_name: gpt-3.5-turbo
|
|
||||||
litellm_params:
|
|
||||||
model: azure/chatgpt-v-2 # actual model name
|
|
||||||
api_key:
|
|
||||||
api_version: 2023-07-01-preview
|
|
||||||
api_base: https://openai-gpt-4-test-v-1.openai.azure.com/
|
|
||||||
```
|
|
||||||
|
|
||||||
3. Test (in another window) → sends 100 simultaneous requests to the queue
|
|
||||||
|
|
||||||
```bash
|
|
||||||
$ litellm --test_async --num_requests 100
|
|
||||||
```
|
|
||||||
|
|
||||||
|
|
||||||
### Available Endpoints
|
|
||||||
- `/queue/request` - Queues a /chat/completions request. Returns a job id.
|
|
||||||
- `/queue/response/{id}` - Returns the status of a job. If completed, returns the response as well. Potential status's are: `queued` and `finished`.
|
|
||||||
|
|
||||||
|
|
||||||
## Hosted Request Queing api.litellm.ai
|
|
||||||
Queue your LLM API requests to ensure you're under your rate limits
|
|
||||||
- Step 1: Step 1 Add a config to the proxy, generate a temp key
|
|
||||||
- Step 2: Queue a request to the proxy, using your generated_key
|
|
||||||
- Step 3: Poll the request
|
|
||||||
|
|
||||||
|
|
||||||
### Step 1 Add a config to the proxy, generate a temp key
|
|
||||||
```python
|
|
||||||
import requests
|
|
||||||
import time
|
|
||||||
import os
|
|
||||||
|
|
||||||
# Set the base URL as needed
|
|
||||||
base_url = "https://api.litellm.ai"
|
|
||||||
|
|
||||||
# Step 1 Add a config to the proxy, generate a temp key
|
|
||||||
# use the same model_name to load balance
|
|
||||||
config = {
|
|
||||||
"model_list": [
|
|
||||||
{
|
|
||||||
"model_name": "gpt-3.5-turbo",
|
|
||||||
"litellm_params": {
|
|
||||||
"model": "gpt-3.5-turbo",
|
|
||||||
"api_key": os.environ['OPENAI_API_KEY'],
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"model_name": "gpt-3.5-turbo",
|
|
||||||
"litellm_params": {
|
|
||||||
"model": "azure/chatgpt-v-2",
|
|
||||||
"api_key": "",
|
|
||||||
"api_base": "https://openai-gpt-4-test-v-1.openai.azure.com/",
|
|
||||||
"api_version": "2023-07-01-preview"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
|
|
||||||
response = requests.post(
|
|
||||||
url=f"{base_url}/key/generate",
|
|
||||||
json={
|
|
||||||
"config": config,
|
|
||||||
"duration": "30d" # default to 30d, set it to 30m if you want a temp 30 minute key
|
|
||||||
},
|
|
||||||
headers={
|
|
||||||
"Authorization": "Bearer sk-hosted-litellm" # this is the key to use api.litellm.ai
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
print("\nresponse from generating key", response.text)
|
|
||||||
print("\n json response from gen key", response.json())
|
|
||||||
|
|
||||||
generated_key = response.json()["key"]
|
|
||||||
print("\ngenerated key for proxy", generated_key)
|
|
||||||
```
|
|
||||||
|
|
||||||
#### Output
|
|
||||||
```shell
|
|
||||||
response from generating key {"key":"sk-...,"expires":"2023-12-22T03:43:57.615000+00:00"}
|
|
||||||
```
|
|
||||||
|
|
||||||
### Step 2: Queue a request to the proxy, using your generated_key
|
|
||||||
```python
|
|
||||||
print("Creating a job on the proxy")
|
|
||||||
job_response = requests.post(
|
|
||||||
url=f"{base_url}/queue/request",
|
|
||||||
json={
|
|
||||||
'model': 'gpt-3.5-turbo',
|
|
||||||
'messages': [
|
|
||||||
{'role': 'system', 'content': f'You are a helpful assistant. What is your name'},
|
|
||||||
],
|
|
||||||
},
|
|
||||||
headers={
|
|
||||||
"Authorization": f"Bearer {generated_key}"
|
|
||||||
}
|
|
||||||
)
|
|
||||||
print(job_response.status_code)
|
|
||||||
print(job_response.text)
|
|
||||||
print("\nResponse from creating job", job_response.text)
|
|
||||||
job_response = job_response.json()
|
|
||||||
job_id = job_response["id"]
|
|
||||||
polling_url = job_response["url"]
|
|
||||||
polling_url = f"{base_url}{polling_url}"
|
|
||||||
print("\nCreated Job, Polling Url", polling_url)
|
|
||||||
```
|
|
||||||
|
|
||||||
#### Output
|
|
||||||
```shell
|
|
||||||
Response from creating job
|
|
||||||
{"id":"0e3d9e98-5d56-4d07-9cc8-c34b7e6658d7","url":"/queue/response/0e3d9e98-5d56-4d07-9cc8-c34b7e6658d7","eta":5,"status":"queued"}
|
|
||||||
```
|
|
||||||
|
|
||||||
### Step 3: Poll the request
|
|
||||||
```python
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
print("\nPolling URL", polling_url)
|
|
||||||
polling_response = requests.get(
|
|
||||||
url=polling_url,
|
|
||||||
headers={
|
|
||||||
"Authorization": f"Bearer {generated_key}"
|
|
||||||
}
|
|
||||||
)
|
|
||||||
print("\nResponse from polling url", polling_response.text)
|
|
||||||
polling_response = polling_response.json()
|
|
||||||
status = polling_response.get("status", None)
|
|
||||||
if status == "finished":
|
|
||||||
llm_response = polling_response["result"]
|
|
||||||
print("LLM Response")
|
|
||||||
print(llm_response)
|
|
||||||
break
|
|
||||||
time.sleep(0.5)
|
|
||||||
except Exception as e:
|
|
||||||
print("got exception in polling", e)
|
|
||||||
break
|
|
||||||
```
|
|
||||||
|
|
||||||
#### Output
|
|
||||||
```shell
|
|
||||||
Polling URL https://api.litellm.ai/queue/response/0e3d9e98-5d56-4d07-9cc8-c34b7e6658d7
|
|
||||||
|
|
||||||
Response from polling url {"status":"queued"}
|
|
||||||
|
|
||||||
Polling URL https://api.litellm.ai/queue/response/0e3d9e98-5d56-4d07-9cc8-c34b7e6658d7
|
|
||||||
|
|
||||||
Response from polling url {"status":"queued"}
|
|
||||||
|
|
||||||
Polling URL https://api.litellm.ai/queue/response/0e3d9e98-5d56-4d07-9cc8-c34b7e6658d7
|
|
||||||
|
|
||||||
Response from polling url
|
|
||||||
{"status":"finished","result":{"id":"chatcmpl-8NYRce4IeI4NzYyodT3NNp8fk5cSW","choices":[{"finish_reason":"stop","index":0,"message":{"content":"I am an AI assistant and do not have a physical presence or personal identity. You can simply refer to me as \"Assistant.\" How may I assist you today?","role":"assistant"}}],"created":1700624639,"model":"gpt-3.5-turbo-0613","object":"chat.completion","system_fingerprint":null,"usage":{"completion_tokens":33,"prompt_tokens":17,"total_tokens":50}}}
|
|
||||||
|
|
||||||
```
|
|
|
@ -61,11 +61,13 @@ const sidebars = {
 },
 items: [
 "providers/openai",
+"providers/openai_compatible",
 "providers/azure",
 "providers/huggingface",
 "providers/ollama",
 "providers/vertex",
 "providers/palm",
+"providers/mistral",
 "providers/anthropic",
 "providers/aws_sagemaker",
 "providers/bedrock",
@ -97,9 +99,13 @@ const sidebars = {
 items: [
 "proxy/quick_start",
 "proxy/configs",
+"proxy/embedding",
 "proxy/load_balancing",
 "proxy/virtual_keys",
 "proxy/model_management",
+"proxy/reliability",
+"proxy/health",
+"proxy/call_hooks",
 "proxy/caching",
 "proxy/logging",
 "proxy/cli",
@ -189,6 +195,7 @@ const sidebars = {
 slug: '/project',
 },
 items: [
+"projects/Docq.AI",
 "projects/OpenInterpreter",
 "projects/FastREPL",
 "projects/PROMPTMETHEUS",
|
|
File diff suppressed because it is too large
|
@ -10,7 +10,7 @@ success_callback: List[Union[str, Callable]] = []
|
||||||
failure_callback: List[Union[str, Callable]] = []
|
failure_callback: List[Union[str, Callable]] = []
|
||||||
callbacks: List[Callable] = []
|
callbacks: List[Callable] = []
|
||||||
_async_input_callback: List[Callable] = [] # internal variable - async custom callbacks are routed here.
|
_async_input_callback: List[Callable] = [] # internal variable - async custom callbacks are routed here.
|
||||||
_async_success_callback: List[Callable] = [] # internal variable - async custom callbacks are routed here.
|
_async_success_callback: List[Union[str, Callable]] = [] # internal variable - async custom callbacks are routed here.
|
||||||
_async_failure_callback: List[Callable] = [] # internal variable - async custom callbacks are routed here.
|
_async_failure_callback: List[Callable] = [] # internal variable - async custom callbacks are routed here.
|
||||||
pre_call_rules: List[Callable] = []
|
pre_call_rules: List[Callable] = []
|
||||||
post_call_rules: List[Callable] = []
|
post_call_rules: List[Callable] = []
|
||||||
|
@ -48,6 +48,8 @@ cache: Optional[Cache] = None # cache object <- use this - https://docs.litellm.
|
||||||
model_alias_map: Dict[str, str] = {}
|
model_alias_map: Dict[str, str] = {}
|
||||||
model_group_alias_map: Dict[str, str] = {}
|
model_group_alias_map: Dict[str, str] = {}
|
||||||
max_budget: float = 0.0 # set the max budget across all providers
|
max_budget: float = 0.0 # set the max budget across all providers
|
||||||
|
_openai_completion_params = ["functions", "function_call", "temperature", "temperature", "top_p", "n", "stream", "stop", "max_tokens", "presence_penalty", "frequency_penalty", "logit_bias", "user", "request_timeout", "api_base", "api_version", "api_key", "deployment_id", "organization", "base_url", "default_headers", "timeout", "response_format", "seed", "tools", "tool_choice", "max_retries"]
|
||||||
|
_litellm_completion_params = ["metadata", "acompletion", "caching", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response", "self", "client", "rpm", "tpm", "input_cost_per_token", "output_cost_per_token", "hf_model_name", "model_info", "proxy_server_request", "preset_cache_key"]
|
||||||
_current_cost = 0 # private variable, used if max budget is set
|
_current_cost = 0 # private variable, used if max budget is set
|
||||||
error_logs: Dict = {}
|
error_logs: Dict = {}
|
||||||
add_function_to_prompt: bool = False # if function calling not supported by api, append function call details to system prompt
|
add_function_to_prompt: bool = False # if function calling not supported by api, append function call details to system prompt
|
||||||
|
@ -56,6 +58,7 @@ aclient_session: Optional[httpx.AsyncClient] = None
|
||||||
model_fallbacks: Optional[List] = None # Deprecated for 'litellm.fallbacks'
|
model_fallbacks: Optional[List] = None # Deprecated for 'litellm.fallbacks'
|
||||||
model_cost_map_url: str = "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json"
|
model_cost_map_url: str = "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json"
|
||||||
suppress_debug_info = False
|
suppress_debug_info = False
|
||||||
|
dynamodb_table_name: Optional[str] = None
|
||||||
#### RELIABILITY ####
|
#### RELIABILITY ####
|
||||||
request_timeout: Optional[float] = 6000
|
request_timeout: Optional[float] = 6000
|
||||||
num_retries: Optional[int] = None
|
num_retries: Optional[int] = None
|
||||||
|
@ -107,6 +110,8 @@ open_ai_text_completion_models: List = []
|
||||||
cohere_models: List = []
|
cohere_models: List = []
|
||||||
anthropic_models: List = []
|
anthropic_models: List = []
|
||||||
openrouter_models: List = []
|
openrouter_models: List = []
|
||||||
|
vertex_language_models: List = []
|
||||||
|
vertex_vision_models: List = []
|
||||||
vertex_chat_models: List = []
|
vertex_chat_models: List = []
|
||||||
vertex_code_chat_models: List = []
|
vertex_code_chat_models: List = []
|
||||||
vertex_text_models: List = []
|
vertex_text_models: List = []
|
||||||
|
@ -133,6 +138,10 @@ for key, value in model_cost.items():
|
||||||
vertex_text_models.append(key)
|
vertex_text_models.append(key)
|
||||||
elif value.get('litellm_provider') == 'vertex_ai-code-text-models':
|
elif value.get('litellm_provider') == 'vertex_ai-code-text-models':
|
||||||
vertex_code_text_models.append(key)
|
vertex_code_text_models.append(key)
|
||||||
|
elif value.get('litellm_provider') == 'vertex_ai-language-models':
|
||||||
|
vertex_language_models.append(key)
|
||||||
|
elif value.get('litellm_provider') == 'vertex_ai-vision-models':
|
||||||
|
vertex_vision_models.append(key)
|
||||||
elif value.get('litellm_provider') == 'vertex_ai-chat-models':
|
elif value.get('litellm_provider') == 'vertex_ai-chat-models':
|
||||||
vertex_chat_models.append(key)
|
vertex_chat_models.append(key)
|
||||||
elif value.get('litellm_provider') == 'vertex_ai-code-chat-models':
|
elif value.get('litellm_provider') == 'vertex_ai-code-chat-models':
|
||||||
|
@ -154,7 +163,16 @@ for key, value in model_cost.items():
|
||||||
openai_compatible_endpoints: List = [
|
openai_compatible_endpoints: List = [
|
||||||
"api.perplexity.ai",
|
"api.perplexity.ai",
|
||||||
"api.endpoints.anyscale.com/v1",
|
"api.endpoints.anyscale.com/v1",
|
||||||
"api.deepinfra.com/v1/openai"
|
"api.deepinfra.com/v1/openai",
|
||||||
|
"api.mistral.ai/v1"
|
||||||
|
]
|
||||||
|
|
||||||
|
# this is maintained for Exception Mapping
|
||||||
|
openai_compatible_providers: List = [
|
||||||
|
"anyscale",
|
||||||
|
"mistral",
|
||||||
|
"deepinfra",
|
||||||
|
"perplexity"
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@ -266,6 +284,7 @@ model_list = (
|
||||||
provider_list: List = [
|
provider_list: List = [
|
||||||
"openai",
|
"openai",
|
||||||
"custom_openai",
|
"custom_openai",
|
||||||
|
"text-completion-openai",
|
||||||
"cohere",
|
"cohere",
|
||||||
"anthropic",
|
"anthropic",
|
||||||
"replicate",
|
"replicate",
|
||||||
|
@ -287,6 +306,7 @@ provider_list: List = [
|
||||||
"deepinfra",
|
"deepinfra",
|
||||||
"perplexity",
|
"perplexity",
|
||||||
"anyscale",
|
"anyscale",
|
||||||
|
"mistral",
|
||||||
"maritalk",
|
"maritalk",
|
||||||
"custom", # custom apis
|
"custom", # custom apis
|
||||||
]
|
]
|
||||||
|
@ -396,6 +416,7 @@ from .exceptions import (
|
||||||
AuthenticationError,
|
AuthenticationError,
|
||||||
InvalidRequestError,
|
InvalidRequestError,
|
||||||
BadRequestError,
|
BadRequestError,
|
||||||
|
NotFoundError,
|
||||||
RateLimitError,
|
RateLimitError,
|
||||||
ServiceUnavailableError,
|
ServiceUnavailableError,
|
||||||
OpenAIError,
|
OpenAIError,
|
||||||
|
@ -404,7 +425,8 @@ from .exceptions import (
|
||||||
APIError,
|
APIError,
|
||||||
Timeout,
|
Timeout,
|
||||||
APIConnectionError,
|
APIConnectionError,
|
||||||
APIResponseValidationError
|
APIResponseValidationError,
|
||||||
|
UnprocessableEntityError
|
||||||
)
|
)
|
||||||
from .budget_manager import BudgetManager
|
from .budget_manager import BudgetManager
|
||||||
from .proxy.proxy_cli import run_server
|
from .proxy.proxy_cli import run_server
|
||||||
|
|
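With `mistral` added to the provider list and `api.mistral.ai/v1` registered as an OpenAI-compatible endpoint in the hunks above, calling a Mistral-hosted model should follow the usual provider-prefix pattern. The environment variable and model name below are illustrative assumptions, not confirmed by this diff:

```python
# Call a Mistral API model through litellm's new "mistral" provider
import os
import litellm

os.environ["MISTRAL_API_KEY"] = "sk-..."   # assumed env var name for the provider key

response = litellm.completion(
    model="mistral/mistral-tiny",           # provider prefix + model name (illustrative)
    messages=[{"role": "user", "content": "hello from litellm"}],
)
print(response.choices[0].message.content)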
|
@ -10,19 +10,7 @@
|
||||||
import litellm
|
import litellm
|
||||||
import time, logging
|
import time, logging
|
||||||
import json, traceback, ast
|
import json, traceback, ast
|
||||||
from typing import Optional
|
from typing import Optional, Literal, List
|
||||||
|
|
||||||
def get_prompt(*args, **kwargs):
|
|
||||||
# make this safe checks, it should not throw any exceptions
|
|
||||||
if len(args) > 1:
|
|
||||||
messages = args[1]
|
|
||||||
prompt = " ".join(message["content"] for message in messages)
|
|
||||||
return prompt
|
|
||||||
if "messages" in kwargs:
|
|
||||||
messages = kwargs["messages"]
|
|
||||||
prompt = " ".join(message["content"] for message in messages)
|
|
||||||
return prompt
|
|
||||||
return None
|
|
||||||
|
|
||||||
def print_verbose(print_statement):
|
def print_verbose(print_statement):
|
||||||
try:
|
try:
|
||||||
|
@ -174,34 +162,36 @@ class DualCache(BaseCache):
|
||||||
if self.redis_cache is not None:
|
if self.redis_cache is not None:
|
||||||
self.redis_cache.flush_cache()
|
self.redis_cache.flush_cache()
|
||||||
|
|
||||||
#### LiteLLM.Completion Cache ####
|
#### LiteLLM.Completion / Embedding Cache ####
|
||||||
class Cache:
|
class Cache:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
type="local",
|
type: Optional[Literal["local", "redis"]] = "local",
|
||||||
host=None,
|
host: Optional[str] = None,
|
||||||
port=None,
|
port: Optional[str] = None,
|
||||||
password=None,
|
password: Optional[str] = None,
|
||||||
|
supported_call_types: Optional[List[Literal["completion", "acompletion", "embedding", "aembedding"]]] = ["completion", "acompletion", "embedding", "aembedding"],
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Initializes the cache based on the given type.
|
Initializes the cache based on the given type.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
type (str, optional): The type of cache to initialize. Defaults to "local".
|
type (str, optional): The type of cache to initialize. Can be "local" or "redis". Defaults to "local".
|
||||||
host (str, optional): The host address for the Redis cache. Required if type is "redis".
|
host (str, optional): The host address for the Redis cache. Required if type is "redis".
|
||||||
port (int, optional): The port number for the Redis cache. Required if type is "redis".
|
port (int, optional): The port number for the Redis cache. Required if type is "redis".
|
||||||
password (str, optional): The password for the Redis cache. Required if type is "redis".
|
password (str, optional): The password for the Redis cache. Required if type is "redis".
|
||||||
|
supported_call_types (list, optional): List of call types to cache for. Defaults to cache == on for all call types.
|
||||||
**kwargs: Additional keyword arguments for redis.Redis() cache
|
**kwargs: Additional keyword arguments for redis.Redis() cache
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: If an invalid cache type is provided.
|
ValueError: If an invalid cache type is provided.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
None
|
None. Cache is set as a litellm param
|
||||||
"""
|
"""
|
||||||
if type == "redis":
|
if type == "redis":
|
||||||
self.cache = RedisCache(host, port, password, **kwargs)
|
self.cache: BaseCache = RedisCache(host, port, password, **kwargs)
|
||||||
if type == "local":
|
if type == "local":
|
||||||
self.cache = InMemoryCache()
|
self.cache = InMemoryCache()
|
||||||
if "cache" not in litellm.input_callback:
|
if "cache" not in litellm.input_callback:
|
||||||
|
@ -210,6 +200,7 @@ class Cache:
|
||||||
litellm.success_callback.append("cache")
|
litellm.success_callback.append("cache")
|
||||||
if "cache" not in litellm._async_success_callback:
|
if "cache" not in litellm._async_success_callback:
|
||||||
litellm._async_success_callback.append("cache")
|
litellm._async_success_callback.append("cache")
|
||||||
|
self.supported_call_types = supported_call_types # default to ["completion", "acompletion", "embedding", "aembedding"]
|
||||||
|
|
||||||
def get_cache_key(self, *args, **kwargs):
|
def get_cache_key(self, *args, **kwargs):
|
||||||
"""
|
"""
|
||||||
|
@ -222,29 +213,55 @@ class Cache:
|
||||||
Returns:
|
Returns:
|
||||||
str: The cache key generated from the arguments, or None if no cache key could be generated.
|
str: The cache key generated from the arguments, or None if no cache key could be generated.
|
||||||
"""
|
"""
|
||||||
cache_key =""
|
cache_key = ""
|
||||||
|
print_verbose(f"\nGetting Cache key. Kwargs: {kwargs}")
|
||||||
|
|
||||||
|
# for streaming, we use preset_cache_key. It's created in wrapper(), we do this because optional params like max_tokens, get transformed for bedrock -> max_new_tokens
|
||||||
|
if kwargs.get("litellm_params", {}).get("preset_cache_key", None) is not None:
|
||||||
|
print_verbose(f"\nReturning preset cache key: {cache_key}")
|
||||||
|
return kwargs.get("litellm_params", {}).get("preset_cache_key", None)
|
||||||
|
|
||||||
# sort kwargs by keys, since model: [gpt-4, temperature: 0.2, max_tokens: 200] == [temperature: 0.2, max_tokens: 200, model: gpt-4]
|
# sort kwargs by keys, since model: [gpt-4, temperature: 0.2, max_tokens: 200] == [temperature: 0.2, max_tokens: 200, model: gpt-4]
|
||||||
completion_kwargs = ["model", "messages", "temperature", "top_p", "n", "stop", "max_tokens", "presence_penalty", "frequency_penalty", "logit_bias", "user", "response_format", "seed", "tools", "tool_choice"]
|
completion_kwargs = ["model", "messages", "temperature", "top_p", "n", "stop", "max_tokens", "presence_penalty", "frequency_penalty", "logit_bias", "user", "response_format", "seed", "tools", "tool_choice"]
|
||||||
for param in completion_kwargs:
|
embedding_only_kwargs = ["input", "encoding_format"] # embedding kwargs = model, input, user, encoding_format. Model, user are checked in completion_kwargs
|
||||||
|
|
||||||
|
# combined_kwargs - NEEDS to be ordered across get_cache_key(). Do not use a set()
|
||||||
|
combined_kwargs = completion_kwargs + embedding_only_kwargs
|
||||||
|
for param in combined_kwargs:
|
||||||
# ignore litellm params here
|
# ignore litellm params here
|
||||||
if param in kwargs:
|
if param in kwargs:
|
||||||
# check if param == model and model_group is passed in, then override model with model_group
|
# check if param == model and model_group is passed in, then override model with model_group
|
||||||
if param == "model":
|
if param == "model":
|
||||||
model_group = None
|
model_group = None
|
||||||
|
caching_group = None
|
||||||
metadata = kwargs.get("metadata", None)
|
metadata = kwargs.get("metadata", None)
|
||||||
litellm_params = kwargs.get("litellm_params", {})
|
litellm_params = kwargs.get("litellm_params", {})
|
||||||
if metadata is not None:
|
if metadata is not None:
|
||||||
model_group = metadata.get("model_group")
|
model_group = metadata.get("model_group")
|
||||||
|
model_group = metadata.get("model_group", None)
|
||||||
|
caching_groups = metadata.get("caching_groups", None)
|
||||||
|
if caching_groups:
|
||||||
|
for group in caching_groups:
|
||||||
|
if model_group in group:
|
||||||
|
caching_group = group
|
||||||
|
break
|
||||||
if litellm_params is not None:
|
if litellm_params is not None:
|
||||||
metadata = litellm_params.get("metadata", None)
|
metadata = litellm_params.get("metadata", None)
|
||||||
if metadata is not None:
|
if metadata is not None:
|
||||||
model_group = metadata.get("model_group", None)
|
model_group = metadata.get("model_group", None)
|
||||||
param_value = model_group or kwargs[param] # use model_group if it exists, else use kwargs["model"]
|
caching_groups = metadata.get("caching_groups", None)
|
||||||
|
if caching_groups:
|
||||||
|
for group in caching_groups:
|
||||||
|
if model_group in group:
|
||||||
|
caching_group = group
|
||||||
|
break
|
||||||
|
param_value = caching_group or model_group or kwargs[param] # use caching_group, if set then model_group if it exists, else use kwargs["model"]
|
||||||
else:
|
else:
|
||||||
if kwargs[param] is None:
|
if kwargs[param] is None:
|
||||||
continue # ignore None params
|
continue # ignore None params
|
||||||
param_value = kwargs[param]
|
param_value = kwargs[param]
|
||||||
cache_key+= f"{str(param)}: {str(param_value)}"
|
cache_key+= f"{str(param)}: {str(param_value)}"
|
||||||
|
print_verbose(f"\nCreated cache key: {cache_key}")
|
||||||
return cache_key
|
return cache_key
|
||||||
|
|
||||||
def generate_streaming_content(self, content):
|
def generate_streaming_content(self, content):
|
||||||
|
@ -297,4 +314,9 @@ class Cache:
|
||||||
result = result.model_dump_json()
|
result = result.model_dump_json()
|
||||||
self.cache.set_cache(cache_key, result, **kwargs)
|
self.cache.set_cache(cache_key, result, **kwargs)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
print_verbose(f"LiteLLM Cache: Excepton add_cache: {str(e)}")
|
||||||
|
traceback.print_exc()
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
async def _async_add_cache(self, result, *args, **kwargs):
|
||||||
|
self.add_cache(result, *args, **kwargs)
|
|
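The new `supported_call_types` parameter on `Cache` lets callers scope caching to specific call types. A minimal usage sketch (in-memory cache, completion-only; embedding calls would bypass the cache under this configuration):

```python
# Cache only completion/acompletion results
import litellm
from litellm.caching import Cache

litellm.cache = Cache(
    type="local",
    supported_call_types=["completion", "acompletion"],
)

response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hello"}],
)
```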
@ -12,16 +12,19 @@
|
||||||
from openai import (
|
from openai import (
|
||||||
AuthenticationError,
|
AuthenticationError,
|
||||||
BadRequestError,
|
BadRequestError,
|
||||||
|
NotFoundError,
|
||||||
RateLimitError,
|
RateLimitError,
|
||||||
APIStatusError,
|
APIStatusError,
|
||||||
OpenAIError,
|
OpenAIError,
|
||||||
APIError,
|
APIError,
|
||||||
APITimeoutError,
|
APITimeoutError,
|
||||||
APIConnectionError,
|
APIConnectionError,
|
||||||
APIResponseValidationError
|
APIResponseValidationError,
|
||||||
|
UnprocessableEntityError
|
||||||
)
|
)
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
|
|
||||||
class AuthenticationError(AuthenticationError): # type: ignore
|
class AuthenticationError(AuthenticationError): # type: ignore
|
||||||
def __init__(self, message, llm_provider, model, response: httpx.Response):
|
def __init__(self, message, llm_provider, model, response: httpx.Response):
|
||||||
self.status_code = 401
|
self.status_code = 401
|
||||||
|
@ -34,6 +37,20 @@ class AuthenticationError(AuthenticationError): # type: ignore
|
||||||
body=None
|
body=None
|
||||||
) # Call the base class constructor with the parameters it needs
|
) # Call the base class constructor with the parameters it needs
|
||||||
|
|
||||||
|
# raise when invalid models passed, example gpt-8
|
||||||
|
class NotFoundError(NotFoundError): # type: ignore
|
||||||
|
def __init__(self, message, model, llm_provider, response: httpx.Response):
|
||||||
|
self.status_code = 404
|
||||||
|
self.message = message
|
||||||
|
self.model = model
|
||||||
|
self.llm_provider = llm_provider
|
||||||
|
super().__init__(
|
||||||
|
self.message,
|
||||||
|
response=response,
|
||||||
|
body=None
|
||||||
|
) # Call the base class constructor with the parameters it needs
|
||||||
|
|
||||||
|
|
||||||
class BadRequestError(BadRequestError): # type: ignore
|
class BadRequestError(BadRequestError): # type: ignore
|
||||||
def __init__(self, message, model, llm_provider, response: httpx.Response):
|
def __init__(self, message, model, llm_provider, response: httpx.Response):
|
||||||
self.status_code = 400
|
self.status_code = 400
|
||||||
|
@ -46,6 +63,18 @@ class BadRequestError(BadRequestError): # type: ignore
|
||||||
body=None
|
body=None
|
||||||
) # Call the base class constructor with the parameters it needs
|
) # Call the base class constructor with the parameters it needs
|
||||||
|
|
||||||
|
class UnprocessableEntityError(UnprocessableEntityError): # type: ignore
|
||||||
|
def __init__(self, message, model, llm_provider, response: httpx.Response):
|
||||||
|
self.status_code = 422
|
||||||
|
self.message = message
|
||||||
|
self.model = model
|
||||||
|
self.llm_provider = llm_provider
|
||||||
|
super().__init__(
|
||||||
|
self.message,
|
||||||
|
response=response,
|
||||||
|
body=None
|
||||||
|
) # Call the base class constructor with the parameters it needs
|
||||||
|
|
||||||
class Timeout(APITimeoutError): # type: ignore
|
class Timeout(APITimeoutError): # type: ignore
|
||||||
def __init__(self, message, model, llm_provider):
|
def __init__(self, message, model, llm_provider):
|
||||||
self.status_code = 408
|
self.status_code = 408
|
||||||
|
|
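These new classes map provider 404/422 responses onto OpenAI-style exceptions, and the `__init__.py` hunk above re-exports them from the top-level `litellm` package. A hedged handling sketch; per the `# raise when invalid models passed, example gpt-8` comment, an invalid model may surface as `NotFoundError`, though some providers report such cases as 400s instead:

```python
# Catch the new 404 mapping; fall back to BadRequestError for providers that use 400s
import litellm

try:
    litellm.completion(
        model="gpt-8",  # invalid model, per the "example gpt-8" comment above
        messages=[{"role": "user", "content": "hi"}],
    )
except litellm.NotFoundError as e:        # provider returned a 404
    print("not found:", e.llm_provider, e.model, e.status_code)
except litellm.BadRequestError as e:      # some invalid-model cases surface as 400s instead
    print("bad request:", e.status_code)
```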
|
@ -2,8 +2,9 @@
|
||||||
# On success, logs events to Promptlayer
|
# On success, logs events to Promptlayer
|
||||||
import dotenv, os
|
import dotenv, os
|
||||||
import requests
|
import requests
|
||||||
import requests
|
from litellm.proxy._types import UserAPIKeyAuth
|
||||||
|
from litellm.caching import DualCache
|
||||||
|
from typing import Literal
|
||||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
dotenv.load_dotenv() # Loading env variables using dotenv
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
|
@ -28,6 +29,11 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callback
|
||||||
def log_failure_event(self, kwargs, response_obj, start_time, end_time):
|
def log_failure_event(self, kwargs, response_obj, start_time, end_time):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
#### ASYNC ####
|
||||||
|
|
||||||
|
async def async_log_stream_event(self, kwargs, response_obj, start_time, end_time):
|
||||||
|
pass
|
||||||
|
|
||||||
async def async_log_pre_api_call(self, model, messages, kwargs):
|
async def async_log_pre_api_call(self, model, messages, kwargs):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@ -37,6 +43,16 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callback
|
||||||
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
|
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
#### CALL HOOKS - proxy only ####
|
||||||
|
"""
|
||||||
|
Control / modify the incoming and outgoing data before calling the model
|
||||||
|
"""
|
||||||
|
async def async_pre_call_hook(self, user_api_key_dict: UserAPIKeyAuth, cache: DualCache, data: dict, call_type: Literal["completion", "embeddings"]):
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def async_post_call_failure_hook(self, original_exception: Exception, user_api_key_dict: UserAPIKeyAuth):
|
||||||
|
pass
|
||||||
|
|
||||||
#### SINGLE-USE #### - https://docs.litellm.ai/docs/observability/custom_callback#using-your-custom-callback-function
|
#### SINGLE-USE #### - https://docs.litellm.ai/docs/observability/custom_callback#using-your-custom-callback-function
|
||||||
|
|
||||||
def log_input_event(self, model, messages, kwargs, print_verbose, callback_func):
|
def log_input_event(self, model, messages, kwargs, print_verbose, callback_func):
|
||||||
|
|
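The new proxy-only call hooks let a `CustomLogger` subclass inspect or rewrite a request before the model is called. A hedged sketch using only the imports and signatures shown in this hunk; how the returned value and the `user_id` attribute are consumed by the proxy is assumed, not confirmed here:

```python
# Proxy call hook: tag every request with the calling key's user id
from typing import Literal

from litellm.caching import DualCache
from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy._types import UserAPIKeyAuth


class MyProxyHooks(CustomLogger):
    async def async_pre_call_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        cache: DualCache,
        data: dict,
        call_type: Literal["completion", "embeddings"],
    ):
        # attribute name on the auth object is an assumption
        data["user"] = getattr(user_api_key_dict, "user_id", None) or "proxy-user"
        return data  # returning the (possibly modified) request dict is assumed

    async def async_post_call_failure_hook(
        self, original_exception: Exception, user_api_key_dict: UserAPIKeyAuth
    ):
        print(f"proxy call failed: {original_exception}")


proxy_handler_instance = MyProxyHooks()  # wired up via the proxy's callback settings (assumption)
```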
litellm/integrations/dynamodb.py (new file, 82 lines)
@ -0,0 +1,82 @@
|
||||||
|
#### What this does ####
|
||||||
|
# On success + failure, log events to DynamoDB
|
||||||
|
|
||||||
|
import dotenv, os
|
||||||
|
import requests
|
||||||
|
|
||||||
|
dotenv.load_dotenv() # Loading env variables using dotenv
|
||||||
|
import traceback
|
||||||
|
import datetime, subprocess, sys
|
||||||
|
import litellm, uuid
|
||||||
|
from litellm._logging import print_verbose
|
||||||
|
|
||||||
|
class DyanmoDBLogger:
|
||||||
|
# Class variables or attributes
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
# Instance variables
|
||||||
|
import boto3
|
||||||
|
self.dynamodb = boto3.resource('dynamodb', region_name=os.environ["AWS_REGION_NAME"])
|
||||||
|
if litellm.dynamodb_table_name is None:
|
||||||
|
raise ValueError("LiteLLM Error, trying to use DynamoDB but not table name passed. Create a table and set `litellm.dynamodb_table_name=<your-table>`")
|
||||||
|
self.table_name = litellm.dynamodb_table_name
|
||||||
|
|
||||||
|
async def _async_log_event(self, kwargs, response_obj, start_time, end_time, print_verbose):
|
||||||
|
self.log_event(kwargs, response_obj, start_time, end_time, print_verbose)
|
||||||
|
def log_event(self, kwargs, response_obj, start_time, end_time, print_verbose):
|
||||||
|
try:
|
||||||
|
print_verbose(
|
||||||
|
f"DynamoDB Logging - Enters logging function for model {kwargs}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# construct payload to send to DynamoDB
|
||||||
|
# follows the same params as langfuse.py
|
||||||
|
litellm_params = kwargs.get("litellm_params", {})
|
||||||
|
metadata = litellm_params.get("metadata", {}) or {} # if litellm_params['metadata'] == None
|
||||||
|
messages = kwargs.get("messages")
|
||||||
|
optional_params = kwargs.get("optional_params", {})
|
||||||
|
call_type = kwargs.get("call_type", "litellm.completion")
|
||||||
|
usage = response_obj["usage"]
|
||||||
|
id = response_obj.get("id", str(uuid.uuid4()))
|
||||||
|
|
||||||
|
# Build the initial payload
|
||||||
|
payload = {
|
||||||
|
"id": id,
|
||||||
|
"call_type": call_type,
|
||||||
|
"startTime": start_time,
|
||||||
|
"endTime": end_time,
|
||||||
|
"model": kwargs.get("model", ""),
|
||||||
|
"user": kwargs.get("user", ""),
|
||||||
|
"modelParameters": optional_params,
|
||||||
|
"messages": messages,
|
||||||
|
"response": response_obj,
|
||||||
|
"usage": usage,
|
||||||
|
"metadata": metadata
|
||||||
|
}
|
||||||
|
|
||||||
|
# Ensure everything in the payload is converted to str
|
||||||
|
for key, value in payload.items():
|
||||||
|
try:
|
||||||
|
payload[key] = str(value)
|
||||||
|
except:
|
||||||
|
# non blocking if it can't cast to a str
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
print_verbose(f"\nDynamoDB Logger - Logging payload = {payload}")
|
||||||
|
|
||||||
|
# put data in DynamoDB
|
||||||
|
table = self.dynamodb.Table(self.table_name)
|
||||||
|
# Assuming log_data is a dictionary with log information
|
||||||
|
response = table.put_item(Item=payload)
|
||||||
|
|
||||||
|
print_verbose(f"Response from DynamoDB:{str(response)}")
|
||||||
|
|
||||||
|
print_verbose(
|
||||||
|
f"DynamoDB Layer Logging - final response object: {response_obj}"
|
||||||
|
)
|
||||||
|
return response
|
||||||
|
except:
|
||||||
|
traceback.print_exc()
|
||||||
|
print_verbose(f"DynamoDB Layer Error - {traceback.format_exc()}")
|
||||||
|
pass
|
|
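Outside the proxy, the same logger can presumably be enabled directly on the litellm SDK, since it is driven by `litellm.success_callback` and the new `litellm.dynamodb_table_name` setting shown in this diff (AWS credentials and region come from the environment):

```python
# Log successful SDK calls to DynamoDB
import litellm

litellm.success_callback = ["dynamodb"]
litellm.dynamodb_table_name = "your-table-name"   # must name an existing table

response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "This is a test"}],
)
```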
@ -58,7 +58,7 @@ class LangFuseLogger:
|
||||||
model=kwargs['model'],
|
model=kwargs['model'],
|
||||||
modelParameters=optional_params,
|
modelParameters=optional_params,
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
completion=response_obj['choices'][0]['message'],
|
completion=response_obj['choices'][0]['message'].json(),
|
||||||
usage=Usage(
|
usage=Usage(
|
||||||
prompt_tokens=response_obj['usage']['prompt_tokens'],
|
prompt_tokens=response_obj['usage']['prompt_tokens'],
|
||||||
completion_tokens=response_obj['usage']['completion_tokens']
|
completion_tokens=response_obj['usage']['completion_tokens']
|
||||||
|
@ -70,6 +70,9 @@ class LangFuseLogger:
|
||||||
f"Langfuse Layer Logging - final response object: {response_obj}"
|
f"Langfuse Layer Logging - final response object: {response_obj}"
|
||||||
)
|
)
|
||||||
except:
|
except:
|
||||||
# traceback.print_exc()
|
traceback.print_exc()
|
||||||
print_verbose(f"Langfuse Layer Error - {traceback.format_exc()}")
|
print_verbose(f"Langfuse Layer Error - {traceback.format_exc()}")
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
async def _async_log_event(self, kwargs, response_obj, start_time, end_time, print_verbose):
|
||||||
|
self.log_event(kwargs, response_obj, start_time, end_time, print_verbose)
|
||||||
|
|
|
@ -58,7 +58,7 @@ class LangsmithLogger:
|
||||||
"inputs": {
|
"inputs": {
|
||||||
**new_kwargs
|
**new_kwargs
|
||||||
},
|
},
|
||||||
"outputs": response_obj,
|
"outputs": response_obj.json(),
|
||||||
"session_name": project_name,
|
"session_name": project_name,
|
||||||
"start_time": start_time,
|
"start_time": start_time,
|
||||||
"end_time": end_time,
|
"end_time": end_time,
|
||||||
|
|
|
@ -1,7 +1,8 @@
|
||||||
class TraceloopLogger:
|
class TraceloopLogger:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
from traceloop.sdk.tracing.tracing import TracerWrapper
|
from traceloop.sdk.tracing.tracing import TracerWrapper
|
||||||
|
from traceloop.sdk import Traceloop
|
||||||
|
Traceloop.init(app_name="Litellm-Server", disable_batch=True)
|
||||||
self.tracer_wrapper = TracerWrapper()
|
self.tracer_wrapper = TracerWrapper()
|
||||||
|
|
||||||
def log_event(self, kwargs, response_obj, start_time, end_time, print_verbose):
|
def log_event(self, kwargs, response_obj, start_time, end_time, print_verbose):
|
||||||
|
|
|
@@ -196,8 +196,19 @@ class AzureChatCompletion(BaseLLM):
                 else:
                     azure_client = client
                 response = azure_client.chat.completions.create(**data) # type: ignore
-                response.model = "azure/" + str(response.model)
-                return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response)
+                stringified_response = response.model_dump_json()
+                ## LOGGING
+                logging_obj.post_call(
+                    input=messages,
+                    api_key=api_key,
+                    original_response=stringified_response,
+                    additional_args={
+                        "headers": headers,
+                        "api_version": api_version,
+                        "api_base": api_base,
+                    },
+                )
+                return convert_to_model_response_object(response_object=json.loads(stringified_response), model_response_object=model_response)
         except AzureOpenAIError as e:
             exception_mapping_worked = True
             raise e
@@ -318,7 +329,10 @@ class AzureChatCompletion(BaseLLM):
                          data: dict,
                          model_response: ModelResponse,
                          azure_client_params: dict,
+                         api_key: str,
+                         input: list,
                          client=None,
+                         logging_obj=None
     ):
         response = None
         try:
@@ -327,8 +341,23 @@ class AzureChatCompletion(BaseLLM):
             else:
                 openai_aclient = client
             response = await openai_aclient.embeddings.create(**data)
-            return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response, response_type="embedding")
+            stringified_response = response.model_dump_json()
+            ## LOGGING
+            logging_obj.post_call(
+                input=input,
+                api_key=api_key,
+                additional_args={"complete_input_dict": data},
+                original_response=stringified_response,
+            )
+            return convert_to_model_response_object(response_object=json.loads(stringified_response), model_response_object=model_response, response_type="embedding")
         except Exception as e:
+            ## LOGGING
+            logging_obj.post_call(
+                input=input,
+                api_key=api_key,
+                additional_args={"complete_input_dict": data},
+                original_response=str(e),
+            )
             raise e

     def embedding(self,
@@ -372,13 +401,7 @@ class AzureChatCompletion(BaseLLM):
             azure_client_params["api_key"] = api_key
         elif azure_ad_token is not None:
             azure_client_params["azure_ad_token"] = azure_ad_token
-        if aembedding == True:
-            response = self.aembedding(data=data, model_response=model_response, azure_client_params=azure_client_params)
-            return response
-        if client is None:
-            azure_client = AzureOpenAI(**azure_client_params) # type: ignore
-        else:
-            azure_client = client
         ## LOGGING
         logging_obj.pre_call(
             input=input,
@@ -391,6 +414,14 @@ class AzureChatCompletion(BaseLLM):
                 }
             },
         )
+
+        if aembedding == True:
+            response = self.aembedding(data=data, input=input, logging_obj=logging_obj, api_key=api_key, model_response=model_response, azure_client_params=azure_client_params)
+            return response
+        if client is None:
+            azure_client = AzureOpenAI(**azure_client_params) # type: ignore
+        else:
+            azure_client = client
         ## COMPLETION CALL
         response = azure_client.embeddings.create(**data) # type: ignore
         ## LOGGING
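With the Azure embedding path above now routing async requests after the pre-call log, a minimal usage sketch looks roughly like this (deployment name, endpoint, key and API version are placeholders, not values from this commit):

    import litellm

    # hedged sketch: all credentials below are placeholders
    response = litellm.embedding(
        model="azure/my-embedding-deployment",
        input=["good morning from litellm"],
        api_key="...",
        api_base="https://my-endpoint.openai.azure.com/",
        api_version="2023-07-01-preview",
    )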
@@ -482,7 +482,7 @@ def completion(
         logging_obj.post_call(
             input=prompt,
             api_key="",
-            original_response=response_body,
+            original_response=json.dumps(response_body),
             additional_args={"complete_input_dict": data},
         )
         print_verbose(f"raw model_response: {response}")
@@ -552,6 +552,7 @@ def _embedding_func_single(
     ## FORMAT EMBEDDING INPUT ##
     provider = model.split(".")[0]
     inference_params = copy.deepcopy(optional_params)
+    inference_params.pop("user", None) # make sure user is not passed in for bedrock call
     if provider == "amazon":
         input = input.replace(os.linesep, " ")
         data = {"inputText": input, **inference_params}
@@ -587,7 +588,7 @@ def _embedding_func_single(
         input=input,
         api_key="",
         additional_args={"complete_input_dict": data},
-        original_response=response_body,
+        original_response=json.dumps(response_body),
     )
     if provider == "cohere":
         response = response_body.get("embeddings")
@@ -651,13 +652,4 @@ def embedding(
     )
     model_response.usage = usage

-    ## LOGGING
-    logging_obj.post_call(
-        input=input,
-        api_key=api_key,
-        additional_args={"complete_input_dict": {"model": model,
-                                                 "texts": input}},
-        original_response=embeddings,
-    )
-
     return model_response
@@ -542,7 +542,7 @@ class Huggingface(BaseLLM):
         logging_obj.pre_call(
             input=input,
             api_key=api_key,
-            additional_args={"complete_input_dict": data},
+            additional_args={"complete_input_dict": data, "headers": headers, "api_base": embed_url},
         )
         ## COMPLETION CALL
         response = requests.post(
@@ -584,6 +584,14 @@ class Huggingface(BaseLLM):
                         "embedding": embedding # flatten list returned from hf
                     }
                 )
+            elif isinstance(embedding, list) and isinstance(embedding[0], float):
+                output_data.append(
+                    {
+                        "object": "embedding",
+                        "index": idx,
+                        "embedding": embedding # flatten list returned from hf
+                    }
+                )
             else:
                 output_data.append(
                     {
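The new `elif` branch covers Hugging Face endpoints that return a single flat list of floats instead of one nested list per input. A minimal usage sketch for the embedding path, assuming a Hugging Face API token is available in the environment and that the model id is only an example:

    import litellm

    # hedged sketch: model id is illustrative; requires HUGGINGFACE_API_KEY (or an api_base for a dedicated endpoint)
    response = litellm.embedding(
        model="huggingface/sentence-transformers/all-MiniLM-L6-v2",
        input=["good morning from litellm"],
    )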
@@ -1,10 +1,9 @@
-import requests, types
+import requests, types, time
 import json
 import traceback
 from typing import Optional
 import litellm
-import httpx
+import httpx, aiohttp, asyncio

 try:
     from async_generator import async_generator, yield_ # optional dependency
     async_generator_imported = True
@@ -115,6 +114,9 @@ def get_ollama_response_stream(
     prompt="Why is the sky blue?",
     optional_params=None,
     logging_obj=None,
+    acompletion: bool = False,
+    model_response=None,
+    encoding=None
 ):
     if api_base.endswith("/api/generate"):
         url = api_base
@@ -136,8 +138,19 @@ def get_ollama_response_stream(
     logging_obj.pre_call(
         input=None,
         api_key=None,
-        additional_args={"api_base": url, "complete_input_dict": data},
+        additional_args={"api_base": url, "complete_input_dict": data, "headers": {}, "acompletion": acompletion,},
     )
+    if acompletion is True:
+        if optional_params.get("stream", False):
+            response = ollama_async_streaming(url=url, data=data, model_response=model_response, encoding=encoding, logging_obj=logging_obj)
+        else:
+            response = ollama_acompletion(url=url, data=data, model_response=model_response, encoding=encoding, logging_obj=logging_obj)
+        return response
+
+    else:
+        return ollama_completion_stream(url=url, data=data)
+
+def ollama_completion_stream(url, data):
     session = requests.Session()

     with session.post(url, json=data, stream=True) as resp:
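A minimal usage sketch for the new async Ollama routing, assuming litellm is installed, a local Ollama server on the default port, and a pulled llama2 model (all of which are illustrative assumptions): a non-streaming call goes through ollama_acompletion, while stream=True goes through ollama_async_streaming.

    import asyncio
    import litellm

    async def main():
        # hedged sketch: model name and api_base are placeholders for a local Ollama setup
        response = await litellm.acompletion(
            model="ollama/llama2",
            messages=[{"role": "user", "content": "Why is the sky blue?"}],
            api_base="http://localhost:11434",
        )
        print(response)

    asyncio.run(main())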
@@ -169,41 +182,38 @@ def get_ollama_response_stream(
                 traceback.print_exc()
     session.close()

-if async_generator_imported:
-    # ollama implementation
-    @async_generator
-    async def async_get_ollama_response_stream(
-        api_base="http://localhost:11434",
-        model="llama2",
-        prompt="Why is the sky blue?",
-        optional_params=None,
-        logging_obj=None,
-    ):
-        url = f"{api_base}/api/generate"
-
-        ## Load Config
-        config=litellm.OllamaConfig.get_config()
-        for k, v in config.items():
-            if k not in optional_params: # completion(top_k=3) > cohere_config(top_k=3) <- allows for dynamic variables to be passed in
-                optional_params[k] = v
-
-        data = {
-            "model": model,
-            "prompt": prompt,
-            **optional_params
-        }
-        ## LOGGING
-        logging_obj.pre_call(
-            input=None,
-            api_key=None,
-            additional_args={"api_base": url, "complete_input_dict": data},
-        )
-        session = requests.Session()
-
-        with session.post(url, json=data, stream=True) as resp:
-            if resp.status_code != 200:
-                raise OllamaError(status_code=resp.status_code, message=resp.text)
-            for line in resp.iter_lines():
+async def ollama_async_streaming(url, data, model_response, encoding, logging_obj):
+    try:
+        client = httpx.AsyncClient()
+        async with client.stream(
+            url=f"{url}",
+            json=data,
+            method="POST",
+            timeout=litellm.request_timeout
+        ) as response:
+            if response.status_code != 200:
+                raise OllamaError(status_code=response.status_code, message=response.text)
+
+            streamwrapper = litellm.CustomStreamWrapper(completion_stream=response.aiter_lines(), model=data['model'], custom_llm_provider="ollama",logging_obj=logging_obj)
+            async for transformed_chunk in streamwrapper:
+                yield transformed_chunk
+    except Exception as e:
+        traceback.print_exc()
+
+async def ollama_acompletion(url, data, model_response, encoding, logging_obj):
+    data["stream"] = False
+    try:
+        timeout = aiohttp.ClientTimeout(total=600) # 10 minutes
+        async with aiohttp.ClientSession(timeout=timeout) as session:
+            resp = await session.post(url, json=data)
+
+            if resp.status != 200:
+                text = await resp.text()
+                raise OllamaError(status_code=resp.status, message=text)
+
+            completion_string = ""
+            async for line in resp.content.iter_any():
                 if line:
                     try:
                         json_chunk = line.decode("utf-8")
@@ -217,15 +227,24 @@ if async_generator_imported:
                                 "content": "",
                                 "error": j
                             }
-                            await yield_({"choices": [{"delta": completion_obj}]})
+                            raise Exception(f"OllamError - {chunk}")
                         if "response" in j:
                             completion_obj = {
                                 "role": "assistant",
-                                "content": "",
+                                "content": j["response"],
                             }
-                            completion_obj["content"] = j["response"]
-                            await yield_({"choices": [{"delta": completion_obj}]})
+                            completion_string = completion_string + completion_obj["content"]
                     except Exception as e:
-                        import logging
-                        logging.debug(f"Error decoding JSON: {e}")
-    session.close()
+                        traceback.print_exc()
+
+            ## RESPONSE OBJECT
+            model_response["choices"][0]["finish_reason"] = "stop"
+            model_response["choices"][0]["message"]["content"] = completion_string
+            model_response["created"] = int(time.time())
+            model_response["model"] = "ollama/" + data['model']
+            prompt_tokens = len(encoding.encode(data['prompt'])) # type: ignore
+            completion_tokens = len(encoding.encode(completion_string))
+            model_response["usage"] = litellm.Usage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=prompt_tokens + completion_tokens)
+            return model_response
+    except Exception as e:
+        traceback.print_exc()
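The non-streaming async path above counts usage with the encoding object litellm passes in. A small standalone sketch of the same accounting, using tiktoken only for illustration (litellm supplies its own encoder, so the encoding name here is an assumption):

    import tiktoken

    # hedged sketch: cl100k_base is just an example encoding
    encoding = tiktoken.get_encoding("cl100k_base")
    prompt = "Why is the sky blue?"
    completion_string = "Because of Rayleigh scattering."
    prompt_tokens = len(encoding.encode(prompt))
    completion_tokens = len(encoding.encode(completion_string))
    total_tokens = prompt_tokens + completion_tokens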
@@ -195,23 +195,23 @@ class OpenAIChatCompletion(BaseLLM):
                 **optional_params
             }

-            ## LOGGING
-            logging_obj.pre_call(
-                input=messages,
-                api_key=api_key,
-                additional_args={"headers": headers, "api_base": api_base, "acompletion": acompletion, "complete_input_dict": data},
-            )
-
             try:
                 max_retries = data.pop("max_retries", 2)
                 if acompletion is True:
                     if optional_params.get("stream", False):
-                        return self.async_streaming(logging_obj=logging_obj, data=data, model=model, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries)
+                        return self.async_streaming(logging_obj=logging_obj, headers=headers, data=data, model=model, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries)
                     else:
-                        return self.acompletion(data=data, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries)
+                        return self.acompletion(data=data, headers=headers, logging_obj=logging_obj, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries)
                 elif optional_params.get("stream", False):
-                    return self.streaming(logging_obj=logging_obj, data=data, model=model, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries)
+                    return self.streaming(logging_obj=logging_obj, headers=headers, data=data, model=model, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries)
                 else:
+                    ## LOGGING
+                    logging_obj.pre_call(
+                        input=messages,
+                        api_key=api_key,
+                        additional_args={"headers": headers, "api_base": api_base, "acompletion": acompletion, "complete_input_dict": data},
+                    )
+
                     if not isinstance(max_retries, int):
                         raise OpenAIError(status_code=422, message="max retries must be an int")
                     if client is None:
@@ -219,13 +219,14 @@ class OpenAIChatCompletion(BaseLLM):
                     else:
                         openai_client = client
                     response = openai_client.chat.completions.create(**data) # type: ignore
+                    stringified_response = response.model_dump_json()
                     logging_obj.post_call(
-                        input=None,
+                        input=messages,
                         api_key=api_key,
-                        original_response=response,
+                        original_response=stringified_response,
                         additional_args={"complete_input_dict": data},
                     )
-                    return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response)
+                    return convert_to_model_response_object(response_object=json.loads(stringified_response), model_response_object=model_response)
             except Exception as e:
                 if "Conversation roles must alternate user/assistant" in str(e) or "user and assistant roles should be alternating" in str(e):
                     # reformat messages to ensure user/assistant are alternating, if there's either 2 consecutive 'user' messages or 2 consecutive 'assistant' message, add a blank 'user' or 'assistant' message to ensure compatibility
@@ -259,6 +260,8 @@ class OpenAIChatCompletion(BaseLLM):
                     api_base: Optional[str]=None,
                     client=None,
                     max_retries=None,
+                    logging_obj=None,
+                    headers=None
                     ):
         response = None
         try:
@@ -266,16 +269,23 @@ class OpenAIChatCompletion(BaseLLM):
                 openai_aclient = AsyncOpenAI(api_key=api_key, base_url=api_base, http_client=litellm.aclient_session, timeout=timeout, max_retries=max_retries)
             else:
                 openai_aclient = client
+            ## LOGGING
+            logging_obj.pre_call(
+                input=data['messages'],
+                api_key=openai_aclient.api_key,
+                additional_args={"headers": {"Authorization": f"Bearer {openai_aclient.api_key}"}, "api_base": openai_aclient._base_url._uri_reference, "acompletion": True, "complete_input_dict": data},
+            )
             response = await openai_aclient.chat.completions.create(**data)
-            return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response)
+            stringified_response = response.model_dump_json()
+            logging_obj.post_call(
+                input=data['messages'],
+                api_key=api_key,
+                original_response=stringified_response,
+                additional_args={"complete_input_dict": data},
+            )
+            return convert_to_model_response_object(response_object=json.loads(stringified_response), model_response_object=model_response)
         except Exception as e:
-            if response and hasattr(response, "text"):
-                raise OpenAIError(status_code=500, message=f"{str(e)}\n\nOriginal Response: {response.text}")
-            else:
-                if type(e).__name__ == "ReadTimeout":
-                    raise OpenAIError(status_code=408, message=f"{type(e).__name__}")
-                else:
-                    raise OpenAIError(status_code=500, message=f"{str(e)}")
+            raise e

     def streaming(self,
                   logging_obj,
@@ -285,12 +295,19 @@ class OpenAIChatCompletion(BaseLLM):
                   api_key: Optional[str]=None,
                   api_base: Optional[str]=None,
                   client = None,
-                  max_retries=None
+                  max_retries=None,
+                  headers=None
                   ):
         if client is None:
             openai_client = OpenAI(api_key=api_key, base_url=api_base, http_client=litellm.client_session, timeout=timeout, max_retries=max_retries)
         else:
             openai_client = client
+        ## LOGGING
+        logging_obj.pre_call(
+            input=data['messages'],
+            api_key=api_key,
+            additional_args={"headers": headers, "api_base": api_base, "acompletion": False, "complete_input_dict": data},
+        )
         response = openai_client.chat.completions.create(**data)
         streamwrapper = CustomStreamWrapper(completion_stream=response, model=model, custom_llm_provider="openai",logging_obj=logging_obj)
         return streamwrapper
@@ -304,6 +321,7 @@ class OpenAIChatCompletion(BaseLLM):
                           api_base: Optional[str]=None,
                           client=None,
                           max_retries=None,
+                          headers=None
                           ):
         response = None
         try:
@@ -311,6 +329,13 @@ class OpenAIChatCompletion(BaseLLM):
                 openai_aclient = AsyncOpenAI(api_key=api_key, base_url=api_base, http_client=litellm.aclient_session, timeout=timeout, max_retries=max_retries)
             else:
                 openai_aclient = client
+            ## LOGGING
+            logging_obj.pre_call(
+                input=data['messages'],
+                api_key=api_key,
+                additional_args={"headers": headers, "api_base": api_base, "acompletion": True, "complete_input_dict": data},
+            )
+
             response = await openai_aclient.chat.completions.create(**data)
             streamwrapper = CustomStreamWrapper(completion_stream=response, model=model, custom_llm_provider="openai",logging_obj=logging_obj)
             async for transformed_chunk in streamwrapper:
@@ -325,6 +350,7 @@ class OpenAIChatCompletion(BaseLLM):
                 raise OpenAIError(status_code=500, message=f"{str(e)}")
     async def aembedding(
         self,
+        input: list,
         data: dict,
         model_response: ModelResponse,
         timeout: float,
@@ -332,6 +358,7 @@ class OpenAIChatCompletion(BaseLLM):
         api_base: Optional[str]=None,
         client=None,
         max_retries=None,
+        logging_obj=None
     ):
         response = None
         try:
@@ -340,9 +367,24 @@ class OpenAIChatCompletion(BaseLLM):
             else:
                 openai_aclient = client
             response = await openai_aclient.embeddings.create(**data) # type: ignore
-            return response
+            stringified_response = response.model_dump_json()
+            ## LOGGING
+            logging_obj.post_call(
+                input=input,
+                api_key=api_key,
+                additional_args={"complete_input_dict": data},
+                original_response=stringified_response,
+            )
+            return convert_to_model_response_object(response_object=json.loads(stringified_response), model_response_object=model_response, response_type="embedding") # type: ignore
         except Exception as e:
+            ## LOGGING
+            logging_obj.post_call(
+                input=input,
+                api_key=api_key,
+                original_response=str(e),
+            )
             raise e

     def embedding(self,
                   model: str,
                   input: list,
|
@ -367,13 +409,6 @@ class OpenAIChatCompletion(BaseLLM):
|
||||||
max_retries = data.pop("max_retries", 2)
|
max_retries = data.pop("max_retries", 2)
|
||||||
if not isinstance(max_retries, int):
|
if not isinstance(max_retries, int):
|
||||||
raise OpenAIError(status_code=422, message="max retries must be an int")
|
raise OpenAIError(status_code=422, message="max retries must be an int")
|
||||||
if aembedding == True:
|
|
||||||
response = self.aembedding(data=data, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries) # type: ignore
|
|
||||||
return response
|
|
||||||
if client is None:
|
|
||||||
openai_client = OpenAI(api_key=api_key, base_url=api_base, http_client=litellm.client_session, timeout=timeout, max_retries=max_retries)
|
|
||||||
else:
|
|
||||||
openai_client = client
|
|
||||||
## LOGGING
|
## LOGGING
|
||||||
logging_obj.pre_call(
|
logging_obj.pre_call(
|
||||||
input=input,
|
input=input,
|
||||||
|
@ -381,6 +416,14 @@ class OpenAIChatCompletion(BaseLLM):
|
||||||
additional_args={"complete_input_dict": data, "api_base": api_base},
|
additional_args={"complete_input_dict": data, "api_base": api_base},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if aembedding == True:
|
||||||
|
response = self.aembedding(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries) # type: ignore
|
||||||
|
return response
|
||||||
|
if client is None:
|
||||||
|
openai_client = OpenAI(api_key=api_key, base_url=api_base, http_client=litellm.client_session, timeout=timeout, max_retries=max_retries)
|
||||||
|
else:
|
||||||
|
openai_client = client
|
||||||
|
|
||||||
## COMPLETION CALL
|
## COMPLETION CALL
|
||||||
response = openai_client.embeddings.create(**data) # type: ignore
|
response = openai_client.embeddings.create(**data) # type: ignore
|
||||||
## LOGGING
|
## LOGGING
|
||||||
|
@@ -472,12 +515,14 @@ class OpenAITextCompletion(BaseLLM):
             else:
                 prompt = " ".join([message["content"] for message in messages]) # type: ignore

+            # don't send max retries to the api, if set
+            optional_params.pop("max_retries", None)
+
             data = {
                 "model": model,
                 "prompt": prompt,
                 **optional_params
             }

             ## LOGGING
             logging_obj.pre_call(
                 input=messages,
@@ -73,8 +73,27 @@ def ollama_pt(model, messages): # https://github.com/jmorganca/ollama/blob/af4cf
             final_prompt_value="### Response:",
             messages=messages
         )
+    elif "llava" in model:
+        prompt = ""
+        images = []
+        for message in messages:
+            if isinstance(message["content"], str):
+                prompt += message["content"]
+            elif isinstance(message["content"], list):
+                # see https://docs.litellm.ai/docs/providers/openai#openai-vision-models
+                for element in message["content"]:
+                    if isinstance(element, dict):
+                        if element["type"] == "text":
+                            prompt += element["text"]
+                        elif element["type"] == "image_url":
+                            image_url = element["image_url"]["url"]
+                            images.append(image_url)
+        return {
+            "prompt": prompt,
+            "images": images
+        }
     else:
-        prompt = "".join(m["content"] for m in messages)
+        prompt = "".join(m["content"] if isinstance(m['content'], str) is str else "".join(m['content']) for m in messages)
     return prompt

 def mistral_instruct_pt(messages):
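The new llava branch accepts OpenAI-vision-style content lists and splits them into a prompt string plus a list of image URLs. A hedged sketch of the shape it consumes and what it returns (values are illustrative):

    # hedged sketch of the message shape the llava branch handles
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "What is in this picture?"},
                {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
            ],
        }
    ]
    # ollama_pt("llava", messages) would return roughly:
    # {"prompt": "What is in this picture?", "images": ["https://example.com/cat.png"]}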
@@ -161,6 +180,8 @@ def phind_codellama_pt(messages):

 def hf_chat_template(model: str, messages: list, chat_template: Optional[Any]=None):
     ## get the tokenizer config from huggingface
+    bos_token = ""
+    eos_token = ""
     if chat_template is None:
         def _get_tokenizer_config(hf_model_name):
             url = f"https://huggingface.co/{hf_model_name}/raw/main/tokenizer_config.json"
@@ -187,7 +208,10 @@ def hf_chat_template(model: str, messages: list, chat_template: Optional[Any]=No
     # Create a template object from the template text
     env = Environment()
     env.globals['raise_exception'] = raise_exception
-    template = env.from_string(chat_template)
+    try:
+        template = env.from_string(chat_template)
+    except Exception as e:
+        raise e

     def _is_system_in_template():
         try:
@@ -227,8 +251,8 @@ def hf_chat_template(model: str, messages: list, chat_template: Optional[Any]=No
             new_messages.append(reformatted_messages[-1])
         rendered_text = template.render(bos_token=bos_token, eos_token=eos_token, messages=new_messages)
         return rendered_text
-    except:
-        raise Exception("Error rendering template")
+    except Exception as e:
+        raise Exception(f"Error rendering template - {str(e)}")

 # Anthropic template
 def claude_2_1_pt(messages: list): # format - https://docs.anthropic.com/claude/docs/how-to-use-system-prompts
@@ -266,20 +290,26 @@ def claude_2_1_pt(messages: list): # format - https://docs.anthropic.com/claude/
 ### TOGETHER AI

 def get_model_info(token, model):
-    headers = {
-        'Authorization': f'Bearer {token}'
-    }
-    response = requests.get('https://api.together.xyz/models/info', headers=headers)
-    if response.status_code == 200:
-        model_info = response.json()
-        for m in model_info:
-            if m["name"].lower().strip() == model.strip():
-                return m['config'].get('prompt_format', None), m['config'].get('chat_template', None)
-        return None, None
-    else:
+    try:
+        headers = {
+            'Authorization': f'Bearer {token}'
+        }
+        response = requests.get('https://api.together.xyz/models/info', headers=headers)
+        if response.status_code == 200:
+            model_info = response.json()
+            for m in model_info:
+                if m["name"].lower().strip() == model.strip():
+                    return m['config'].get('prompt_format', None), m['config'].get('chat_template', None)
+            return None, None
+        else:
+            return None, None
+    except Exception as e: # safely fail a prompt template request
         return None, None

 def format_prompt_togetherai(messages, prompt_format, chat_template):
+    if prompt_format is None:
+        return default_pt(messages)
+
     human_prompt, assistant_prompt = prompt_format.split('{prompt}')

     if chat_template is not None:
@@ -232,7 +232,8 @@ def completion(
     if system_prompt is not None:
         input_data = {
             "prompt": prompt,
-            "system_prompt": system_prompt
+            "system_prompt": system_prompt,
+            **optional_params
         }
     # Otherwise, use the prompt as is
     else:
@@ -158,6 +158,7 @@ def completion(
         )
     except Exception as e:
         raise SagemakerError(status_code=500, message=f"{str(e)}")
+
     response = response["Body"].read().decode("utf8")
     ## LOGGING
     logging_obj.post_call(
@@ -171,10 +172,17 @@ def completion(
     completion_response = json.loads(response)
     try:
         completion_response_choices = completion_response[0]
+        completion_output = ""
         if "generation" in completion_response_choices:
-            model_response["choices"][0]["message"]["content"] = completion_response_choices["generation"]
+            completion_output += completion_response_choices["generation"]
         elif "generated_text" in completion_response_choices:
-            model_response["choices"][0]["message"]["content"] = completion_response_choices["generated_text"]
+            completion_output += completion_response_choices["generated_text"]
+
+        # check if the prompt template is part of output, if so - filter it out
+        if completion_output.startswith(prompt) and "<s>" in prompt:
+            completion_output = completion_output.replace(prompt, "", 1)
+
+        model_response["choices"][0]["message"]["content"] = completion_output
     except:
        raise SagemakerError(message=f"LiteLLM Error: Unable to parse sagemaker RAW RESPONSE {json.dumps(completion_response)}", status_code=500)
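Some SageMaker LLaMA deployments echo the prompt template back at the start of generated_text; the new check strips that echo once. A small standalone illustration of the filtering (strings are made up for the example):

    # hedged illustration of the prompt-echo filtering added above
    prompt = "<s>[INST] hi [/INST]"
    completion_output = "<s>[INST] hi [/INST] Hello! How can I help?"
    if completion_output.startswith(prompt) and "<s>" in prompt:
        completion_output = completion_output.replace(prompt, "", 1)
    # completion_output is now " Hello! How can I help?"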
@@ -173,10 +173,11 @@ def completion(
             message=json.dumps(completion_response["output"]), status_code=response.status_code
         )

-    if len(completion_response["output"]["choices"][0]["text"]) > 0:
+    if len(completion_response["output"]["choices"][0]["text"]) >= 0:
         model_response["choices"][0]["message"]["content"] = completion_response["output"]["choices"][0]["text"]

     ## CALCULATING USAGE
+    print_verbose(f"CALCULATING TOGETHERAI TOKEN USAGE. Model Response: {model_response}; model_response['choices'][0]['message'].get('content', ''): {model_response['choices'][0]['message'].get('content', None)}")
     prompt_tokens = len(encoding.encode(prompt))
     completion_tokens = len(
         encoding.encode(model_response["choices"][0]["message"].get("content", ""))
@@ -4,7 +4,7 @@ from enum import Enum
 import requests
 import time
 from typing import Callable, Optional
-from litellm.utils import ModelResponse, Usage
+from litellm.utils import ModelResponse, Usage, CustomStreamWrapper
 import litellm
 import httpx

@@ -57,6 +57,108 @@ class VertexAIConfig():
                 and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod))
                 and v is not None}

+def _get_image_bytes_from_url(image_url: str) -> bytes:
+    try:
+        response = requests.get(image_url)
+        response.raise_for_status()  # Raise an error for bad responses (4xx and 5xx)
+        image_bytes = response.content
+        return image_bytes
+    except requests.exceptions.RequestException as e:
+        # Handle any request exceptions (e.g., connection error, timeout)
+        return b''  # Return an empty bytes object or handle the error as needed
+
+
+def _load_image_from_url(image_url: str):
+    """
+    Loads an image from a URL.
+
+    Args:
+        image_url (str): The URL of the image.
+
+    Returns:
+        Image: The loaded image.
+    """
+    from vertexai.preview.generative_models import GenerativeModel, Part, GenerationConfig, Image
+    image_bytes = _get_image_bytes_from_url(image_url)
+    return Image.from_bytes(image_bytes)
+
+def _gemini_vision_convert_messages(
+        messages: list
+):
+    """
+    Converts given messages for GPT-4 Vision to Gemini format.
+
+    Args:
+        messages (list): The messages to convert. Each message can be a dictionary with a "content" key. The content can be a string or a list of elements. If it is a string, it will be concatenated to the prompt. If it is a list, each element will be processed based on its type:
+            - If the element is a dictionary with a "type" key equal to "text", its "text" value will be concatenated to the prompt.
+            - If the element is a dictionary with a "type" key equal to "image_url", its "image_url" value will be added to the list of images.
+
+    Returns:
+        tuple: A tuple containing the prompt (a string) and the processed images (a list of objects representing the images).
+
+    Raises:
+        VertexAIError: If the import of the 'vertexai' module fails, indicating that 'google-cloud-aiplatform' needs to be installed.
+        Exception: If any other exception occurs during the execution of the function.
+
+    Note:
+        This function is based on the code from the 'gemini/getting-started/intro_gemini_python.ipynb' notebook in the 'generative-ai' repository on GitHub.
+        The supported MIME types for images include 'image/png' and 'image/jpeg'.
+
+    Examples:
+        >>> messages = [
+        ...     {"content": "Hello, world!"},
+        ...     {"content": [{"type": "text", "text": "This is a text message."}, {"type": "image_url", "image_url": "example.com/image.png"}]},
+        ... ]
+        >>> _gemini_vision_convert_messages(messages)
+        ('Hello, world!This is a text message.', [<Part object>, <Part object>])
+    """
+    try:
+        import vertexai
+    except:
+        raise VertexAIError(status_code=400,message="vertexai import failed please run `pip install google-cloud-aiplatform`")
+    try:
+        from vertexai.preview.language_models import ChatModel, CodeChatModel, InputOutputTextPair
+        from vertexai.language_models import TextGenerationModel, CodeGenerationModel
+        from vertexai.preview.generative_models import GenerativeModel, Part, GenerationConfig, Image
+
+        # given messages for gpt-4 vision, convert them for gemini
+        # https://github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/getting-started/intro_gemini_python.ipynb
+        prompt = ""
+        images = []
+        for message in messages:
+            if isinstance(message["content"], str):
+                prompt += message["content"]
+            elif isinstance(message["content"], list):
+                # see https://docs.litellm.ai/docs/providers/openai#openai-vision-models
+                for element in message["content"]:
+                    if isinstance(element, dict):
+                        if element["type"] == "text":
+                            prompt += element["text"]
+                        elif element["type"] == "image_url":
+                            image_url = element["image_url"]["url"]
+                            images.append(image_url)
+        # processing images passed to gemini
+        processed_images = []
+        for img in images:
+            if "gs://" in img:
+                # Case 1: Images with Cloud Storage URIs
+                # The supported MIME types for images include image/png and image/jpeg.
+                part_mime = "image/png" if "png" in img else "image/jpeg"
+                google_clooud_part = Part.from_uri(img, mime_type=part_mime)
+                processed_images.append(google_clooud_part)
+            elif "https:/" in img:
+                # Case 2: Images with direct links
+                image = _load_image_from_url(img)
+                processed_images.append(image)
+            elif ".mp4" in img and "gs://" in img:
+                # Case 3: Videos with Cloud Storage URIs
+                part_mime = "video/mp4"
+                google_clooud_part = Part.from_uri(img, mime_type=part_mime)
+                processed_images.append(google_clooud_part)
+        return prompt, processed_images
+    except Exception as e:
+        raise e
+
 def completion(
     model: str,
     messages: list,
@@ -69,6 +171,7 @@ def completion(
     optional_params=None,
     litellm_params=None,
     logger_fn=None,
+    acompletion: bool=False
 ):
     try:
         import vertexai
@@ -77,6 +180,8 @@ def completion(
     try:
         from vertexai.preview.language_models import ChatModel, CodeChatModel, InputOutputTextPair
         from vertexai.language_models import TextGenerationModel, CodeGenerationModel
+        from vertexai.preview.generative_models import GenerativeModel, Part, GenerationConfig
+

         vertexai.init(
             project=vertex_project, location=vertex_location
@@ -90,34 +195,94 @@ def completion(

         # vertexai does not use an API key, it looks for credentials.json in the environment

-        prompt = " ".join([message["content"] for message in messages])
+        prompt = " ".join([message["content"] for message in messages if isinstance(message["content"], str)])

         mode = ""

         request_str = ""
-        if model in litellm.vertex_chat_models:
-            chat_model = ChatModel.from_pretrained(model)
+        response_obj = None
+        if model in litellm.vertex_language_models:
+            llm_model = GenerativeModel(model)
+            mode = ""
+            request_str += f"llm_model = GenerativeModel({model})\n"
+        elif model in litellm.vertex_vision_models:
+            llm_model = GenerativeModel(model)
+            request_str += f"llm_model = GenerativeModel({model})\n"
+            mode = "vision"
+        elif model in litellm.vertex_chat_models:
+            llm_model = ChatModel.from_pretrained(model)
             mode = "chat"
-            request_str += f"chat_model = ChatModel.from_pretrained({model})\n"
+            request_str += f"llm_model = ChatModel.from_pretrained({model})\n"
         elif model in litellm.vertex_text_models:
-            text_model = TextGenerationModel.from_pretrained(model)
+            llm_model = TextGenerationModel.from_pretrained(model)
             mode = "text"
-            request_str += f"text_model = TextGenerationModel.from_pretrained({model})\n"
+            request_str += f"llm_model = TextGenerationModel.from_pretrained({model})\n"
         elif model in litellm.vertex_code_text_models:
-            text_model = CodeGenerationModel.from_pretrained(model)
+            llm_model = CodeGenerationModel.from_pretrained(model)
             mode = "text"
-            request_str += f"text_model = CodeGenerationModel.from_pretrained({model})\n"
-        else: # vertex_code_chat_models
-            chat_model = CodeChatModel.from_pretrained(model)
+            request_str += f"llm_model = CodeGenerationModel.from_pretrained({model})\n"
+        else: # vertex_code_llm_models
+            llm_model = CodeChatModel.from_pretrained(model)
             mode = "chat"
-            request_str += f"chat_model = CodeChatModel.from_pretrained({model})\n"
+            request_str += f"llm_model = CodeChatModel.from_pretrained({model})\n"

-        if mode == "chat":
-            chat = chat_model.start_chat()
-            request_str+= f"chat = chat_model.start_chat()\n"
+        if acompletion == True: # [TODO] expand support to vertex ai chat + text models
+            if optional_params.get("stream", False) is True:
+                # async streaming
+                return async_streaming(llm_model=llm_model, mode=mode, prompt=prompt, logging_obj=logging_obj, request_str=request_str, model=model, model_response=model_response, messages=messages, print_verbose=print_verbose, **optional_params)
+            return async_completion(llm_model=llm_model, mode=mode, prompt=prompt, logging_obj=logging_obj, request_str=request_str, model=model, model_response=model_response, encoding=encoding, messages=messages,print_verbose=print_verbose,**optional_params)
+
+        if mode == "":
+            chat = llm_model.start_chat()
+            request_str+= f"chat = llm_model.start_chat()\n"
+
+            if "stream" in optional_params and optional_params["stream"] == True:
+                stream = optional_params.pop("stream")
+                request_str += f"chat.send_message({prompt}, generation_config=GenerationConfig(**{optional_params}), stream={stream})\n"
+                ## LOGGING
+                logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
+                model_response = chat.send_message(prompt, generation_config=GenerationConfig(**optional_params), stream=stream)
+                optional_params["stream"] = True
+                return model_response
+            request_str += f"chat.send_message({prompt}, generation_config=GenerationConfig(**{optional_params})).text\n"
             ## LOGGING
+            logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
+            response_obj = chat.send_message(prompt, generation_config=GenerationConfig(**optional_params))
+            completion_response = response_obj.text
+            response_obj = response_obj._raw_response
+        elif mode == "vision":
+            print_verbose("\nMaking VertexAI Gemini Pro Vision Call")
+            print_verbose(f"\nProcessing input messages = {messages}")
+
+            prompt, images = _gemini_vision_convert_messages(messages=messages)
+            content = [prompt] + images
+            if "stream" in optional_params and optional_params["stream"] == True:
+                stream = optional_params.pop("stream")
+                request_str += f"response = llm_model.generate_content({content}, generation_config=GenerationConfig(**{optional_params}), stream={stream})\n"
+                logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
+
+                model_response = llm_model.generate_content(
+                    contents=content,
+                    generation_config=GenerationConfig(**optional_params),
+                    stream=True
+                )
+                optional_params["stream"] = True
+                return model_response
+
+            request_str += f"response = llm_model.generate_content({content})\n"
+            ## LOGGING
+            logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
+
+            ## LLM Call
+            response = llm_model.generate_content(
+                contents=content,
+                generation_config=GenerationConfig(**optional_params)
+            )
+            completion_response = response.text
+            response_obj = response._raw_response
+        elif mode == "chat":
+            chat = llm_model.start_chat()
+            request_str+= f"chat = llm_model.start_chat()\n"

             if "stream" in optional_params and optional_params["stream"] == True:
                 # NOTE: VertexAI does not accept stream=True as a param and raises an error,
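With the vision mode wired in above, a Gemini Pro Vision request can be issued through litellm's standard interface. A hedged usage sketch, assuming application-default Google credentials and placeholder project/location values:

    import litellm

    # hedged sketch: project and location are placeholders; ADC credentials are assumed
    litellm.vertex_project = "my-gcp-project"
    litellm.vertex_location = "us-central1"

    response = litellm.completion(
        model="gemini-pro-vision",
        messages=[{
            "role": "user",
            "content": [
                {"type": "text", "text": "Describe this image."},
                {"type": "image_url", "image_url": {"url": "gs://my-bucket/cat.png"}},
            ],
        }],
    )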
@@ -125,27 +290,30 @@ def completion(
                 # after we get the response we add optional_params["stream"] = True, since main.py needs to know it's a streaming response to then transform it for the OpenAI format
                 optional_params.pop("stream", None) # vertex ai raises an error when passing stream in optional params
                 request_str += f"chat.send_message_streaming({prompt}, **{optional_params})\n"
+                ## LOGGING
                 logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
                 model_response = chat.send_message_streaming(prompt, **optional_params)
                 optional_params["stream"] = True
                 return model_response

             request_str += f"chat.send_message({prompt}, **{optional_params}).text\n"
+            ## LOGGING
             logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
             completion_response = chat.send_message(prompt, **optional_params).text
         elif mode == "text":

             if "stream" in optional_params and optional_params["stream"] == True:
                 optional_params.pop("stream", None) # See note above on handling streaming for vertex ai
-                request_str += f"text_model.predict_streaming({prompt}, **{optional_params})\n"
+                request_str += f"llm_model.predict_streaming({prompt}, **{optional_params})\n"
+                ## LOGGING
                 logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
-                model_response = text_model.predict_streaming(prompt, **optional_params)
+                model_response = llm_model.predict_streaming(prompt, **optional_params)
                 optional_params["stream"] = True
                 return model_response

-            request_str += f"text_model.predict({prompt}, **{optional_params}).text\n"
+            request_str += f"llm_model.predict({prompt}, **{optional_params}).text\n"
+            ## LOGGING
             logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
-            completion_response = text_model.predict(prompt, **optional_params).text
+            completion_response = llm_model.predict(prompt, **optional_params).text

         ## LOGGING
         logging_obj.post_call(
@@ -161,22 +329,162 @@ def completion(
         model_response["created"] = int(time.time())
         model_response["model"] = model
         ## CALCULATING USAGE
-        prompt_tokens = len(
-            encoding.encode(prompt)
-        )
-        completion_tokens = len(
-            encoding.encode(model_response["choices"][0]["message"].get("content", ""))
-        )
-        usage = Usage(
-            prompt_tokens=prompt_tokens,
-            completion_tokens=completion_tokens,
-            total_tokens=prompt_tokens + completion_tokens
-        )
+        if model in litellm.vertex_language_models and response_obj is not None:
+            model_response["choices"][0].finish_reason = response_obj.candidates[0].finish_reason.name
+            usage = Usage(prompt_tokens=response_obj.usage_metadata.prompt_token_count,
+                          completion_tokens=response_obj.usage_metadata.candidates_token_count,
+                          total_tokens=response_obj.usage_metadata.total_token_count)
+        else:
+            prompt_tokens = len(
+                encoding.encode(prompt)
+            )
+            completion_tokens = len(
+                encoding.encode(model_response["choices"][0]["message"].get("content", ""))
+            )
+            usage = Usage(
+                prompt_tokens=prompt_tokens,
+                completion_tokens=completion_tokens,
+                total_tokens=prompt_tokens + completion_tokens
+            )
         model_response.usage = usage
         return model_response
     except Exception as e:
         raise VertexAIError(status_code=500, message=str(e))
+
+async def async_completion(llm_model, mode: str, prompt: str, model: str, model_response: ModelResponse, logging_obj=None, request_str=None, encoding=None, messages = None, print_verbose = None, **optional_params):
+    """
+    Add support for acompletion calls for gemini-pro
+    """
+    try:
+        from vertexai.preview.generative_models import GenerationConfig
+
+        if mode == "":
+            # gemini-pro
+            chat = llm_model.start_chat()
+            ## LOGGING
+            logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
+            response_obj = await chat.send_message_async(prompt, generation_config=GenerationConfig(**optional_params))
+            completion_response = response_obj.text
+            response_obj = response_obj._raw_response
+        elif mode == "vision":
+            print_verbose("\nMaking VertexAI Gemini Pro Vision Call")
+            print_verbose(f"\nProcessing input messages = {messages}")
+
+            prompt, images = _gemini_vision_convert_messages(messages=messages)
+            content = [prompt] + images
+
+            request_str += f"response = llm_model.generate_content({content})\n"
+            ## LOGGING
+            logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
+
+            ## LLM Call
+            response = await llm_model._generate_content_async(
+                contents=content,
+                generation_config=GenerationConfig(**optional_params)
+            )
+            completion_response = response.text
+            response_obj = response._raw_response
+        elif mode == "chat":
+            # chat-bison etc.
+            chat = llm_model.start_chat()
+            ## LOGGING
+            logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
+            response_obj = await chat.send_message_async(prompt, **optional_params)
+            completion_response = response_obj.text
+        elif mode == "text":
+            # gecko etc.
+            request_str += f"llm_model.predict({prompt}, **{optional_params}).text\n"
+            ## LOGGING
+            logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
+            response_obj = await llm_model.predict_async(prompt, **optional_params)
+            completion_response = response_obj.text
+
+        ## LOGGING
+        logging_obj.post_call(
+            input=prompt, api_key=None, original_response=completion_response
+        )
+
+        ## RESPONSE OBJECT
+        if len(str(completion_response)) > 0:
+            model_response["choices"][0]["message"][
+                "content"
+            ] = str(completion_response)
+        model_response["choices"][0]["message"]["content"] = str(completion_response)
+        model_response["created"] = int(time.time())
+        model_response["model"] = model
+        ## CALCULATING USAGE
+        if model in litellm.vertex_language_models and response_obj is not None:
+            model_response["choices"][0].finish_reason = response_obj.candidates[0].finish_reason.name
+            usage = Usage(prompt_tokens=response_obj.usage_metadata.prompt_token_count,
+                          completion_tokens=response_obj.usage_metadata.candidates_token_count,
+                          total_tokens=response_obj.usage_metadata.total_token_count)
+        else:
+            prompt_tokens = len(
+                encoding.encode(prompt)
+            )
+            completion_tokens = len(
+                encoding.encode(model_response["choices"][0]["message"].get("content", ""))
+            )
+            usage = Usage(
+                prompt_tokens=prompt_tokens,
+                completion_tokens=completion_tokens,
+                total_tokens=prompt_tokens + completion_tokens
+            )
+        model_response.usage = usage
+        return model_response
+    except Exception as e:
+        raise VertexAIError(status_code=500, message=str(e))
+
+async def async_streaming(llm_model, mode: str, prompt: str, model: str, model_response: ModelResponse, logging_obj=None, request_str=None, messages = None, print_verbose = None, **optional_params):
+    """
+    Add support for async streaming calls for gemini-pro
+    """
+    from vertexai.preview.generative_models import GenerationConfig
+    if mode == "":
# gemini-pro
|
||||||
|
chat = llm_model.start_chat()
|
||||||
|
stream = optional_params.pop("stream")
|
||||||
|
request_str += f"chat.send_message_async({prompt},generation_config=GenerationConfig(**{optional_params}), stream={stream})\n"
|
||||||
|
## LOGGING
|
||||||
|
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
|
||||||
|
response = await chat.send_message_async(prompt, generation_config=GenerationConfig(**optional_params), stream=stream)
|
||||||
|
optional_params["stream"] = True
|
||||||
|
elif mode == "vision":
|
||||||
|
stream = optional_params.pop("stream")
|
||||||
|
|
||||||
|
print_verbose("\nMaking VertexAI Gemini Pro Vision Call")
|
||||||
|
print_verbose(f"\nProcessing input messages = {messages}")
|
||||||
|
|
||||||
|
prompt, images = _gemini_vision_convert_messages(messages=messages)
|
||||||
|
content = [prompt] + images
|
||||||
|
stream = optional_params.pop("stream")
|
||||||
|
request_str += f"response = llm_model.generate_content({content}, generation_config=GenerationConfig(**{optional_params}), stream={stream})\n"
|
||||||
|
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
|
||||||
|
|
||||||
|
response = llm_model._generate_content_streaming_async(
|
||||||
|
contents=content,
|
||||||
|
generation_config=GenerationConfig(**optional_params),
|
||||||
|
stream=True
|
||||||
|
)
|
||||||
|
optional_params["stream"] = True
|
||||||
|
elif mode == "chat":
|
||||||
|
chat = llm_model.start_chat()
|
||||||
|
optional_params.pop("stream", None) # vertex ai raises an error when passing stream in optional params
|
||||||
|
request_str += f"chat.send_message_streaming_async({prompt}, **{optional_params})\n"
|
||||||
|
## LOGGING
|
||||||
|
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
|
||||||
|
response = chat.send_message_streaming_async(prompt, **optional_params)
|
||||||
|
optional_params["stream"] = True
|
||||||
|
elif mode == "text":
|
||||||
|
optional_params.pop("stream", None) # See note above on handling streaming for vertex ai
|
||||||
|
request_str += f"llm_model.predict_streaming_async({prompt}, **{optional_params})\n"
|
||||||
|
## LOGGING
|
||||||
|
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
|
||||||
|
response = llm_model.predict_streaming_async(prompt, **optional_params)
|
||||||
|
|
||||||
|
streamwrapper = CustomStreamWrapper(completion_stream=response, model=model, custom_llm_provider="vertex_ai",logging_obj=logging_obj)
|
||||||
|
async for transformed_chunk in streamwrapper:
|
||||||
|
yield transformed_chunk
|
||||||
|
|
||||||
def embedding():
|
def embedding():
|
||||||
# logic for parsing in - calling - parsing out model embedding calls
|
# logic for parsing in - calling - parsing out model embedding calls
|
||||||
|
|
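The new `async_completion`/`async_streaming` helpers above are what the router below hands vertex_ai traffic to. A minimal usage sketch, assuming Vertex AI credentials are configured and a litellm build that includes this change (the model name `gemini-pro` is taken from the diff; everything else is illustrative):

```python
import asyncio
import litellm

async def main():
    # acompletion now routes vertex_ai models through the async Vertex SDK calls added above
    response = await litellm.acompletion(
        model="gemini-pro",
        messages=[{"role": "user", "content": "Say hi in one sentence."}],
    )
    print(response["choices"][0]["message"]["content"])

asyncio.run(main())
```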
251 litellm/main.py
@@ -14,6 +14,7 @@ import dotenv, traceback, random, asyncio, time, contextvars
from copy import deepcopy
import httpx
import litellm

from litellm import ( # type: ignore
client,
exception_type,
@@ -31,7 +32,8 @@ from litellm.utils import (
mock_completion_streaming_obj,
convert_to_model_response_object,
token_counter,
-Usage
+Usage,
+get_optional_params_embeddings
)
from .llms import (
anthropic,
@@ -171,11 +173,14 @@ async def acompletion(*args, **kwargs):
or custom_llm_provider == "azure"
or custom_llm_provider == "custom_openai"
or custom_llm_provider == "anyscale"
+or custom_llm_provider == "mistral"
or custom_llm_provider == "openrouter"
or custom_llm_provider == "deepinfra"
or custom_llm_provider == "perplexity"
or custom_llm_provider == "text-completion-openai"
-or custom_llm_provider == "huggingface"): # currently implemented aiohttp calls for just azure and openai, soon all.
+or custom_llm_provider == "huggingface"
+or custom_llm_provider == "ollama"
+or custom_llm_provider == "vertex_ai"): # currently implemented aiohttp calls for just azure and openai, soon all.
if kwargs.get("stream", False):
response = completion(*args, **kwargs)
else:
@@ -200,9 +205,12 @@ async def acompletion(*args, **kwargs):

async def _async_streaming(response, model, custom_llm_provider, args):
try:
+print_verbose(f"received response in _async_streaming: {response}")
async for line in response:
+print_verbose(f"line in async streaming: {line}")
yield line
except Exception as e:
+print_verbose(f"error raised _async_streaming: {traceback.format_exc()}")
raise exception_type(
model=model, custom_llm_provider=custom_llm_provider, original_exception=e, completion_kwargs=args,
)
@@ -278,7 +286,7 @@ def completion(

# Optional liteLLM function params
**kwargs,
-) -> ModelResponse:
+) -> Union[ModelResponse, CustomStreamWrapper]:
"""
Perform a completion() using any of litellm supported llms (example gpt-4, gpt-3.5-turbo, claude-2, command-nightly)
Parameters:
@@ -319,7 +327,6 @@ def completion(
######### unpacking kwargs #####################
args = locals()
api_base = kwargs.get('api_base', None)
-return_async = kwargs.get('return_async', False)
mock_response = kwargs.get('mock_response', None)
force_timeout= kwargs.get('force_timeout', 600) ## deprecated
logger_fn = kwargs.get('logger_fn', None)
@@ -344,13 +351,14 @@ def completion(
final_prompt_value = kwargs.get("final_prompt_value", None)
bos_token = kwargs.get("bos_token", None)
eos_token = kwargs.get("eos_token", None)
+preset_cache_key = kwargs.get("preset_cache_key", None)
hf_model_name = kwargs.get("hf_model_name", None)
### ASYNC CALLS ###
acompletion = kwargs.get("acompletion", False)
client = kwargs.get("client", None)
######## end of unpacking kwargs ###########
openai_params = ["functions", "function_call", "temperature", "temperature", "top_p", "n", "stream", "stop", "max_tokens", "presence_penalty", "frequency_penalty", "logit_bias", "user", "request_timeout", "api_base", "api_version", "api_key", "deployment_id", "organization", "base_url", "default_headers", "timeout", "response_format", "seed", "tools", "tool_choice", "max_retries"]
-litellm_params = ["metadata", "acompletion", "caching", "return_async", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response", "self", "client", "rpm", "tpm", "input_cost_per_token", "output_cost_per_token", "hf_model_name", "model_info", "proxy_server_request"]
+litellm_params = ["metadata", "acompletion", "caching", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response", "self", "client", "rpm", "tpm", "input_cost_per_token", "output_cost_per_token", "hf_model_name", "model_info", "proxy_server_request", "preset_cache_key", "caching_groups"]
default_params = openai_params + litellm_params
non_default_params = {k: v for k,v in kwargs.items() if k not in default_params} # model-specific params - pass them straight to the model/provider
if mock_response:
@@ -384,7 +392,6 @@ def completion(
model=deployment_id
custom_llm_provider="azure"
model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base, api_key=api_key)

### REGISTER CUSTOM MODEL PRICING -- IF GIVEN ###
if input_cost_per_token is not None and output_cost_per_token is not None:
litellm.register_model({
@@ -448,7 +455,6 @@ def completion(
# For logging - save the values of the litellm-specific params passed in
litellm_params = get_litellm_params(
acompletion=acompletion,
-return_async=return_async,
api_key=api_key,
force_timeout=force_timeout,
logger_fn=logger_fn,
@@ -460,7 +466,8 @@ def completion(
completion_call_id=id,
metadata=metadata,
model_info=model_info,
-proxy_server_request=proxy_server_request
+proxy_server_request=proxy_server_request,
+preset_cache_key=preset_cache_key
)
logging.update_environment_variables(model=model, user=user, optional_params=optional_params, litellm_params=litellm_params)
if custom_llm_provider == "azure":
@@ -524,23 +531,25 @@ def completion(
client=client # pass AsyncAzureOpenAI, AzureOpenAI client
)

-## LOGGING
-logging.post_call(
-input=messages,
-api_key=api_key,
-original_response=response,
-additional_args={
-"headers": headers,
-"api_version": api_version,
-"api_base": api_base,
-},
-)
+if optional_params.get("stream", False) or acompletion == True:
+## LOGGING
+logging.post_call(
+input=messages,
+api_key=api_key,
+original_response=response,
+additional_args={
+"headers": headers,
+"api_version": api_version,
+"api_base": api_base,
+},
+)
elif (
model in litellm.open_ai_chat_completion_models
or custom_llm_provider == "custom_openai"
or custom_llm_provider == "deepinfra"
or custom_llm_provider == "perplexity"
or custom_llm_provider == "anyscale"
+or custom_llm_provider == "mistral"
or custom_llm_provider == "openai"
or "ft:gpt-3.5-turbo" in model # finetune gpt-3.5-turbo
): # allow user to make an openai call with a custom base
@@ -604,19 +613,19 @@ def completion(
)
raise e

-## LOGGING
-logging.post_call(
-input=messages,
-api_key=api_key,
-original_response=response,
-additional_args={"headers": headers},
-)
+if optional_params.get("stream", False):
+## LOGGING
+logging.post_call(
+input=messages,
+api_key=api_key,
+original_response=response,
+additional_args={"headers": headers},
+)
elif (
custom_llm_provider == "text-completion-openai"
or "ft:babbage-002" in model
or "ft:davinci-002" in model # support for finetuned completion models
):
-# print("calling custom openai provider")
openai.api_type = "openai"

api_base = (
@@ -655,17 +664,6 @@ def completion(
prompt = messages[0]["content"]
else:
prompt = " ".join([message["content"] for message in messages]) # type: ignore
-## LOGGING
-logging.pre_call(
-input=prompt,
-api_key=api_key,
-additional_args={
-"openai_organization": litellm.organization,
-"headers": headers,
-"api_base": api_base,
-"api_type": openai.api_type,
-},
-)
## COMPLETION CALL
model_response = openai_text_completions.completion(
model=model,
@@ -681,9 +679,14 @@ def completion(
logger_fn=logger_fn
)

-# if "stream" in optional_params and optional_params["stream"] == True:
-# response = CustomStreamWrapper(model_response, model, custom_llm_provider="text-completion-openai", logging_obj=logging)
-# return response
+if optional_params.get("stream", False) or acompletion == True:
+## LOGGING
+logging.post_call(
+input=messages,
+api_key=api_key,
+original_response=model_response,
+additional_args={"headers": headers},
+)
response = model_response
elif (
"replicate" in model or
@@ -728,8 +731,16 @@ def completion(
)
if "stream" in optional_params and optional_params["stream"] == True:
# don't try to access stream object,
-response = CustomStreamWrapper(model_response, model, logging_obj=logging, custom_llm_provider="replicate")
-return response
+model_response = CustomStreamWrapper(model_response, model, logging_obj=logging, custom_llm_provider="replicate") # type: ignore
+if optional_params.get("stream", False) or acompletion == True:
+## LOGGING
+logging.post_call(
+input=messages,
+api_key=replicate_key,
+original_response=model_response,
+)

response = model_response

elif custom_llm_provider=="anthropic":
@@ -749,7 +760,7 @@ def completion(
custom_prompt_dict
or litellm.custom_prompt_dict
)
-model_response = anthropic.completion(
+response = anthropic.completion(
model=model,
messages=messages,
api_base=api_base,
@@ -765,9 +776,16 @@ def completion(
)
if "stream" in optional_params and optional_params["stream"] == True:
# don't try to access stream object,
-response = CustomStreamWrapper(model_response, model, custom_llm_provider="anthropic", logging_obj=logging)
-return response
-response = model_response
+response = CustomStreamWrapper(response, model, custom_llm_provider="anthropic", logging_obj=logging)
+if optional_params.get("stream", False) or acompletion == True:
+## LOGGING
+logging.post_call(
+input=messages,
+api_key=api_key,
+original_response=response,
+)
+response = response
elif custom_llm_provider == "nlp_cloud":
nlp_cloud_key = (
api_key or litellm.nlp_cloud_key or get_secret("NLP_CLOUD_API_KEY") or litellm.api_key
@@ -780,7 +798,7 @@ def completion(
or "https://api.nlpcloud.io/v1/gpu/"
)

-model_response = nlp_cloud.completion(
+response = nlp_cloud.completion(
model=model,
messages=messages,
api_base=api_base,
@@ -796,9 +814,17 @@ def completion(

if "stream" in optional_params and optional_params["stream"] == True:
# don't try to access stream object,
-response = CustomStreamWrapper(model_response, model, custom_llm_provider="nlp_cloud", logging_obj=logging)
-return response
-response = model_response
+response = CustomStreamWrapper(response, model, custom_llm_provider="nlp_cloud", logging_obj=logging)
+if optional_params.get("stream", False) or acompletion == True:
+## LOGGING
+logging.post_call(
+input=messages,
+api_key=api_key,
+original_response=response,
+)
+
+response = response
elif custom_llm_provider == "aleph_alpha":
aleph_alpha_key = (
api_key or litellm.aleph_alpha_key or get_secret("ALEPH_ALPHA_API_KEY") or get_secret("ALEPHALPHA_API_KEY") or litellm.api_key
@@ -1100,7 +1126,7 @@ def completion(
)
return response
response = model_response
-elif model in litellm.vertex_chat_models or model in litellm.vertex_code_chat_models or model in litellm.vertex_text_models or model in litellm.vertex_code_text_models:
+elif custom_llm_provider == "vertex_ai":
vertex_ai_project = (litellm.vertex_project
or get_secret("VERTEXAI_PROJECT"))
vertex_ai_location = (litellm.vertex_location
@@ -1117,10 +1143,11 @@ def completion(
encoding=encoding,
vertex_location=vertex_ai_location,
vertex_project=vertex_ai_project,
-logging_obj=logging
+logging_obj=logging,
+acompletion=acompletion
)

-if "stream" in optional_params and optional_params["stream"] == True:
+if "stream" in optional_params and optional_params["stream"] == True and acompletion == False:
response = CustomStreamWrapper(
model_response, model, custom_llm_provider="vertex_ai", logging_obj=logging
)
@@ -1186,6 +1213,7 @@ def completion(
# "SageMaker is currently not supporting streaming responses."

# fake streaming for sagemaker
+print_verbose(f"ENTERS SAGEMAKER CUSTOMSTREAMWRAPPER")
resp_string = model_response["choices"][0]["message"]["content"]
response = CustomStreamWrapper(
resp_string, model, custom_llm_provider="sagemaker", logging_obj=logging
@@ -1200,7 +1228,7 @@ def completion(
custom_prompt_dict
or litellm.custom_prompt_dict
)
-model_response = bedrock.completion(
+response = bedrock.completion(
model=model,
messages=messages,
custom_prompt_dict=litellm.custom_prompt_dict,
@@ -1218,16 +1246,24 @@ def completion(
# don't try to access stream object,
if "ai21" in model:
response = CustomStreamWrapper(
-model_response, model, custom_llm_provider="bedrock", logging_obj=logging
+response, model, custom_llm_provider="bedrock", logging_obj=logging
)
else:
response = CustomStreamWrapper(
-iter(model_response), model, custom_llm_provider="bedrock", logging_obj=logging
+iter(response), model, custom_llm_provider="bedrock", logging_obj=logging
)
-return response
+if optional_params.get("stream", False):
+## LOGGING
+logging.post_call(
+input=messages,
+api_key=None,
+original_response=response,
+)

## RESPONSE OBJECT
-response = model_response
+response = response
elif custom_llm_provider == "vllm":
model_response = vllm.completion(
model=model,
@@ -1273,14 +1309,18 @@ def completion(
)
else:
prompt = prompt_factory(model=model, messages=messages, custom_llm_provider=custom_llm_provider)
-## LOGGING
-if kwargs.get('acompletion', False) == True:
-if optional_params.get("stream", False) == True:
-# assume all ollama responses are streamed
-async_generator = ollama.async_get_ollama_response_stream(api_base, model, prompt, optional_params, logging_obj=logging)
-return async_generator
-generator = ollama.get_ollama_response_stream(api_base, model, prompt, optional_params, logging_obj=logging)
+if isinstance(prompt, dict):
+# for multimode models - ollama/llava prompt_factory returns a dict {
+# "prompt": prompt,
+# "images": images
+# }
+prompt, images = prompt["prompt"], prompt["images"]
+optional_params["images"] = images
+
+## LOGGING
+generator = ollama.get_ollama_response_stream(api_base, model, prompt, optional_params, logging_obj=logging, acompletion=acompletion, model_response=model_response, encoding=encoding)
+if acompletion is True:
+return generator
if optional_params.get("stream", False) == True:
# assume all ollama responses are streamed
response = CustomStreamWrapper(
@@ -1716,8 +1756,7 @@ async def aembedding(*args, **kwargs):
or custom_llm_provider == "anyscale"
or custom_llm_provider == "openrouter"
or custom_llm_provider == "deepinfra"
-or custom_llm_provider == "perplexity"
-or custom_llm_provider == "huggingface"): # currently implemented aiohttp calls for just azure and openai, soon all.
+or custom_llm_provider == "perplexity"): # currently implemented aiohttp calls for just azure and openai, soon all.
# Await normally
init_response = await loop.run_in_executor(None, func_with_context)
if isinstance(init_response, dict) or isinstance(init_response, ModelResponse): ## CACHING SCENARIO
@@ -1781,22 +1820,21 @@ def embedding(
rpm = kwargs.pop("rpm", None)
tpm = kwargs.pop("tpm", None)
model_info = kwargs.get("model_info", None)
+metadata = kwargs.get("metadata", None)
+encoding_format = kwargs.get("encoding_format", None)
proxy_server_request = kwargs.get("proxy_server_request", None)
-aembedding = kwargs.pop("aembedding", None)
+aembedding = kwargs.get("aembedding", None)
-openai_params = ["functions", "function_call", "temperature", "temperature", "top_p", "n", "stream", "stop", "max_tokens", "presence_penalty", "frequency_penalty", "logit_bias", "user", "request_timeout", "api_base", "api_version", "api_key", "deployment_id", "organization", "base_url", "default_headers", "timeout", "response_format", "seed", "tools", "tool_choice", "max_retries", "encoding_format"]
+openai_params = ["user", "request_timeout", "api_base", "api_version", "api_key", "deployment_id", "organization", "base_url", "default_headers", "timeout", "max_retries", "encoding_format"]
-litellm_params = ["metadata", "acompletion", "caching", "return_async", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response", "self", "client", "rpm", "tpm", "input_cost_per_token", "output_cost_per_token", "hf_model_name", "proxy_server_request", "model_info"]
+litellm_params = ["metadata", "aembedding", "caching", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response", "self", "client", "rpm", "tpm", "input_cost_per_token", "output_cost_per_token", "hf_model_name", "proxy_server_request", "model_info", "preset_cache_key", "caching_groups"]
default_params = openai_params + litellm_params
non_default_params = {k: v for k,v in kwargs.items() if k not in default_params} # model-specific params - pass them straight to the model/provider
-optional_params = {}
-for param in non_default_params:
-optional_params[param] = kwargs[param]
model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base, api_key=api_key)
+optional_params = get_optional_params_embeddings(user=user, encoding_format=encoding_format, custom_llm_provider=custom_llm_provider, **non_default_params)

try:
response = None
logging = litellm_logging_obj
-logging.update_environment_variables(model=model, user="", optional_params=optional_params, litellm_params={"timeout": timeout, "azure": azure, "litellm_call_id": litellm_call_id, "logger_fn": logger_fn, "proxy_server_request": proxy_server_request, "model_info": model_info})
+logging.update_environment_variables(model=model, user=user, optional_params=optional_params, litellm_params={"timeout": timeout, "azure": azure, "litellm_call_id": litellm_call_id, "logger_fn": logger_fn, "proxy_server_request": proxy_server_request, "model_info": model_info, "metadata": metadata, "aembedding": aembedding, "preset_cache_key": None, "stream_response": {}})
if azure == True or custom_llm_provider == "azure":
# azure configs
api_type = get_secret("AZURE_API_TYPE") or "azure"
@@ -1936,7 +1974,7 @@ def embedding(
## LOGGING
logging.post_call(
input=input,
-api_key=openai.api_key,
+api_key=api_key,
original_response=str(e),
)
## Map to OpenAI Exception
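The embedding path above now builds provider params through `get_optional_params_embeddings` and forwards `encoding_format`. A minimal usage sketch, assuming an OpenAI key is configured; the model name and input are illustrative only:

```python
import litellm

# encoding_format is passed through the new get_optional_params_embeddings helper
response = litellm.embedding(
    model="text-embedding-ada-002",
    input=["hello from litellm"],
    encoding_format="float",
)
print(len(response["data"][0]["embedding"]))
```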
@@ -1948,6 +1986,59 @@ def embedding(


###### Text Completion ################
+async def atext_completion(*args, **kwargs):
+"""
+Implemented to handle async streaming for the text completion endpoint
+"""
+loop = asyncio.get_event_loop()
+model = args[0] if len(args) > 0 else kwargs["model"]
+### PASS ARGS TO COMPLETION ###
+kwargs["acompletion"] = True
+custom_llm_provider = None
+try:
+# Use a partial function to pass your keyword arguments
+func = partial(text_completion, *args, **kwargs)
+
+# Add the context to the function
+ctx = contextvars.copy_context()
+func_with_context = partial(ctx.run, func)
+
+_, custom_llm_provider, _, _ = get_llm_provider(model=model, api_base=kwargs.get("api_base", None))
+
+if (custom_llm_provider == "openai"
+or custom_llm_provider == "azure"
+or custom_llm_provider == "custom_openai"
+or custom_llm_provider == "anyscale"
+or custom_llm_provider == "mistral"
+or custom_llm_provider == "openrouter"
+or custom_llm_provider == "deepinfra"
+or custom_llm_provider == "perplexity"
+or custom_llm_provider == "text-completion-openai"
+or custom_llm_provider == "huggingface"
+or custom_llm_provider == "ollama"
+or custom_llm_provider == "vertex_ai"): # currently implemented aiohttp calls for just azure and openai, soon all.
+if kwargs.get("stream", False):
+response = text_completion(*args, **kwargs)
+else:
+# Await normally
+init_response = await loop.run_in_executor(None, func_with_context)
+if isinstance(init_response, dict) or isinstance(init_response, ModelResponse): ## CACHING SCENARIO
+response = init_response
+elif asyncio.iscoroutine(init_response):
+response = await init_response
+else:
+# Call the synchronous function using run_in_executor
+response = await loop.run_in_executor(None, func_with_context)
+if kwargs.get("stream", False): # return an async generator
+return _async_streaming(response=response, model=model, custom_llm_provider=custom_llm_provider, args=args)
+else:
+return response
+except Exception as e:
+custom_llm_provider = custom_llm_provider or "openai"
+raise exception_type(
+model=model, custom_llm_provider=custom_llm_provider, original_exception=e, completion_kwargs=args,
+)
+
def text_completion(
prompt: Union[str, List[Union[str, List[Union[str, List[int]]]]]], # Required: The prompt(s) to generate completions for.
model: Optional[str]=None, # Optional: either `model` or `engine` can be set
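A minimal sketch of calling the new async text-completion wrapper added above; the model name and prompt are illustrative, and an OpenAI-compatible key is assumed to be configured:

```python
import asyncio
import litellm

async def main():
    # atext_completion wraps text_completion via run_in_executor, as in the hunk above
    response = await litellm.atext_completion(
        model="gpt-3.5-turbo-instruct",
        prompt="Write a haiku about diffs.",
    )
    print(response["choices"][0]["text"])

asyncio.run(main())
```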
@@ -2079,7 +2170,7 @@ def text_completion(
*args,
**all_params,
)
-#print(response)
text_completion_response["id"] = response.get("id", None)
text_completion_response["object"] = "text_completion"
text_completion_response["created"] = response.get("created", None)
@@ -2294,6 +2385,8 @@ def stream_chunk_builder(chunks: list, messages: Optional[list]=None):
completion_output = combined_content
elif len(combined_arguments) > 0:
completion_output = combined_arguments
+else:
+completion_output = ""
# # Update usage information if needed
try:
response["usage"]["prompt_tokens"] = token_counter(model=model, messages=messages)
@@ -41,6 +41,20 @@
"litellm_provider": "openai",
"mode": "chat"
},
+"gpt-4-1106-preview": {
+"max_tokens": 128000,
+"input_cost_per_token": 0.00001,
+"output_cost_per_token": 0.00003,
+"litellm_provider": "openai",
+"mode": "chat"
+},
+"gpt-4-vision-preview": {
+"max_tokens": 128000,
+"input_cost_per_token": 0.00001,
+"output_cost_per_token": 0.00003,
+"litellm_provider": "openai",
+"mode": "chat"
+},
"gpt-3.5-turbo": {
"max_tokens": 4097,
"input_cost_per_token": 0.0000015,
@@ -62,6 +76,13 @@
"litellm_provider": "openai",
"mode": "chat"
},
+"gpt-3.5-turbo-1106": {
+"max_tokens": 16385,
+"input_cost_per_token": 0.0000010,
+"output_cost_per_token": 0.0000020,
+"litellm_provider": "openai",
+"mode": "chat"
+},
"gpt-3.5-turbo-16k": {
"max_tokens": 16385,
"input_cost_per_token": 0.000003,
@@ -76,6 +97,62 @@
"litellm_provider": "openai",
"mode": "chat"
},
+"ft:gpt-3.5-turbo": {
+"max_tokens": 4097,
+"input_cost_per_token": 0.000012,
+"output_cost_per_token": 0.000016,
+"litellm_provider": "openai",
+"mode": "chat"
+},
+"text-embedding-ada-002": {
+"max_tokens": 8191,
+"input_cost_per_token": 0.0000001,
+"output_cost_per_token": 0.000000,
+"litellm_provider": "openai",
+"mode": "embedding"
+},
+"azure/gpt-4-1106-preview": {
+"max_tokens": 128000,
+"input_cost_per_token": 0.00001,
+"output_cost_per_token": 0.00003,
+"litellm_provider": "azure",
+"mode": "chat"
+},
+"azure/gpt-4-32k": {
+"max_tokens": 8192,
+"input_cost_per_token": 0.00006,
+"output_cost_per_token": 0.00012,
+"litellm_provider": "azure",
+"mode": "chat"
+},
+"azure/gpt-4": {
+"max_tokens": 16385,
+"input_cost_per_token": 0.00003,
+"output_cost_per_token": 0.00006,
+"litellm_provider": "azure",
+"mode": "chat"
+},
+"azure/gpt-3.5-turbo-16k": {
+"max_tokens": 16385,
+"input_cost_per_token": 0.000003,
+"output_cost_per_token": 0.000004,
+"litellm_provider": "azure",
+"mode": "chat"
+},
+"azure/gpt-3.5-turbo": {
+"max_tokens": 4097,
+"input_cost_per_token": 0.0000015,
+"output_cost_per_token": 0.000002,
+"litellm_provider": "azure",
+"mode": "chat"
+},
+"azure/text-embedding-ada-002": {
+"max_tokens": 8191,
+"input_cost_per_token": 0.0000001,
+"output_cost_per_token": 0.000000,
+"litellm_provider": "azure",
+"mode": "embedding"
+},
"text-davinci-003": {
"max_tokens": 4097,
"input_cost_per_token": 0.000002,
@@ -127,6 +204,7 @@
},
"claude-instant-1": {
"max_tokens": 100000,
+"max_output_tokens": 8191,
"input_cost_per_token": 0.00000163,
"output_cost_per_token": 0.00000551,
"litellm_provider": "anthropic",
@@ -134,15 +212,25 @@
},
"claude-instant-1.2": {
"max_tokens": 100000,
-"input_cost_per_token": 0.00000163,
-"output_cost_per_token": 0.00000551,
+"max_output_tokens": 8191,
+"input_cost_per_token": 0.000000163,
+"output_cost_per_token": 0.000000551,
"litellm_provider": "anthropic",
"mode": "chat"
},
"claude-2": {
"max_tokens": 100000,
-"input_cost_per_token": 0.00001102,
-"output_cost_per_token": 0.00003268,
+"max_output_tokens": 8191,
+"input_cost_per_token": 0.000008,
+"output_cost_per_token": 0.000024,
+"litellm_provider": "anthropic",
+"mode": "chat"
+},
+"claude-2.1": {
+"max_tokens": 200000,
+"max_output_tokens": 8191,
+"input_cost_per_token": 0.000008,
+"output_cost_per_token": 0.000024,
"litellm_provider": "anthropic",
"mode": "chat"
},
@@ -227,9 +315,51 @@
"max_tokens": 32000,
"input_cost_per_token": 0.000000125,
"output_cost_per_token": 0.000000125,
-"litellm_provider": "vertex_ai-chat-models",
+"litellm_provider": "vertex_ai-code-chat-models",
"mode": "chat"
},
+"palm/chat-bison": {
+"max_tokens": 4096,
+"input_cost_per_token": 0.000000125,
+"output_cost_per_token": 0.000000125,
+"litellm_provider": "palm",
+"mode": "chat"
+},
+"palm/chat-bison-001": {
+"max_tokens": 4096,
+"input_cost_per_token": 0.000000125,
+"output_cost_per_token": 0.000000125,
+"litellm_provider": "palm",
+"mode": "chat"
+},
+"palm/text-bison": {
+"max_tokens": 8196,
+"input_cost_per_token": 0.000000125,
+"output_cost_per_token": 0.000000125,
+"litellm_provider": "palm",
+"mode": "completion"
+},
+"palm/text-bison-001": {
+"max_tokens": 8196,
+"input_cost_per_token": 0.000000125,
+"output_cost_per_token": 0.000000125,
+"litellm_provider": "palm",
+"mode": "completion"
+},
+"palm/text-bison-safety-off": {
+"max_tokens": 8196,
+"input_cost_per_token": 0.000000125,
+"output_cost_per_token": 0.000000125,
+"litellm_provider": "palm",
+"mode": "completion"
+},
+"palm/text-bison-safety-recitation-off": {
+"max_tokens": 8196,
+"input_cost_per_token": 0.000000125,
+"output_cost_per_token": 0.000000125,
+"litellm_provider": "palm",
+"mode": "completion"
+},
"command-nightly": {
"max_tokens": 4096,
"input_cost_per_token": 0.000015,
|
||||||
},
|
},
|
||||||
"replicate/llama-2-70b-chat:2c1608e18606fad2812020dc541930f2d0495ce32eee50074220b87300bc16e1": {
|
"replicate/llama-2-70b-chat:2c1608e18606fad2812020dc541930f2d0495ce32eee50074220b87300bc16e1": {
|
||||||
"max_tokens": 4096,
|
"max_tokens": 4096,
|
||||||
|
"input_cost_per_token": 0.0000,
|
||||||
|
"output_cost_per_token": 0.0000,
|
||||||
"litellm_provider": "replicate",
|
"litellm_provider": "replicate",
|
||||||
"mode": "chat"
|
"mode": "chat"
|
||||||
},
|
},
|
||||||
|
@ -293,6 +425,7 @@
|
||||||
},
|
},
|
||||||
"openrouter/anthropic/claude-instant-v1": {
|
"openrouter/anthropic/claude-instant-v1": {
|
||||||
"max_tokens": 100000,
|
"max_tokens": 100000,
|
||||||
|
"max_output_tokens": 8191,
|
||||||
"input_cost_per_token": 0.00000163,
|
"input_cost_per_token": 0.00000163,
|
||||||
"output_cost_per_token": 0.00000551,
|
"output_cost_per_token": 0.00000551,
|
||||||
"litellm_provider": "openrouter",
|
"litellm_provider": "openrouter",
|
||||||
|
@ -300,6 +433,7 @@
|
||||||
},
|
},
|
||||||
"openrouter/anthropic/claude-2": {
|
"openrouter/anthropic/claude-2": {
|
||||||
"max_tokens": 100000,
|
"max_tokens": 100000,
|
||||||
|
"max_output_tokens": 8191,
|
||||||
"input_cost_per_token": 0.00001102,
|
"input_cost_per_token": 0.00001102,
|
||||||
"output_cost_per_token": 0.00003268,
|
"output_cost_per_token": 0.00003268,
|
||||||
"litellm_provider": "openrouter",
|
"litellm_provider": "openrouter",
|
||||||
|
@@ -496,20 +630,31 @@
},
"anthropic.claude-v1": {
"max_tokens": 100000,
-"input_cost_per_token": 0.00001102,
-"output_cost_per_token": 0.00003268,
+"max_output_tokens": 8191,
+"input_cost_per_token": 0.000008,
+"output_cost_per_token": 0.000024,
"litellm_provider": "bedrock",
"mode": "chat"
},
"anthropic.claude-v2": {
"max_tokens": 100000,
-"input_cost_per_token": 0.00001102,
-"output_cost_per_token": 0.00003268,
+"max_output_tokens": 8191,
+"input_cost_per_token": 0.000008,
+"output_cost_per_token": 0.000024,
+"litellm_provider": "bedrock",
+"mode": "chat"
+},
+"anthropic.claude-v2:1": {
+"max_tokens": 200000,
+"max_output_tokens": 8191,
+"input_cost_per_token": 0.000008,
+"output_cost_per_token": 0.000024,
"litellm_provider": "bedrock",
"mode": "chat"
},
"anthropic.claude-instant-v1": {
"max_tokens": 100000,
+"max_output_tokens": 8191,
"input_cost_per_token": 0.00000163,
"output_cost_per_token": 0.00000551,
"litellm_provider": "bedrock",
@@ -529,26 +674,80 @@
"litellm_provider": "bedrock",
"mode": "chat"
},
+"meta.llama2-70b-chat-v1": {
+"max_tokens": 4096,
+"input_cost_per_token": 0.00000195,
+"output_cost_per_token": 0.00000256,
+"litellm_provider": "bedrock",
+"mode": "chat"
+},
+"sagemaker/meta-textgeneration-llama-2-7b": {
+"max_tokens": 4096,
+"input_cost_per_token": 0.000,
+"output_cost_per_token": 0.000,
+"litellm_provider": "sagemaker",
+"mode": "completion"
+},
+"sagemaker/meta-textgeneration-llama-2-7b-f": {
+"max_tokens": 4096,
+"input_cost_per_token": 0.000,
+"output_cost_per_token": 0.000,
+"litellm_provider": "sagemaker",
+"mode": "chat"
+},
+"sagemaker/meta-textgeneration-llama-2-13b": {
+"max_tokens": 4096,
+"input_cost_per_token": 0.000,
+"output_cost_per_token": 0.000,
+"litellm_provider": "sagemaker",
+"mode": "completion"
+},
+"sagemaker/meta-textgeneration-llama-2-13b-f": {
+"max_tokens": 4096,
+"input_cost_per_token": 0.000,
+"output_cost_per_token": 0.000,
+"litellm_provider": "sagemaker",
+"mode": "chat"
+},
+"sagemaker/meta-textgeneration-llama-2-70b": {
+"max_tokens": 4096,
+"input_cost_per_token": 0.000,
+"output_cost_per_token": 0.000,
+"litellm_provider": "sagemaker",
+"mode": "completion"
+},
+"sagemaker/meta-textgeneration-llama-2-70b-b-f": {
+"max_tokens": 4096,
+"input_cost_per_token": 0.000,
+"output_cost_per_token": 0.000,
+"litellm_provider": "sagemaker",
+"mode": "chat"
+},
"together-ai-up-to-3b": {
"input_cost_per_token": 0.0000001,
-"output_cost_per_token": 0.0000001
+"output_cost_per_token": 0.0000001,
+"litellm_provider": "together_ai"
},
"together-ai-3.1b-7b": {
"input_cost_per_token": 0.0000002,
-"output_cost_per_token": 0.0000002
+"output_cost_per_token": 0.0000002,
+"litellm_provider": "together_ai"
},
"together-ai-7.1b-20b": {
"max_tokens": 1000,
"input_cost_per_token": 0.0000004,
-"output_cost_per_token": 0.0000004
+"output_cost_per_token": 0.0000004,
+"litellm_provider": "together_ai"
},
"together-ai-20.1b-40b": {
-"input_cost_per_token": 0.000001,
-"output_cost_per_token": 0.000001
+"input_cost_per_token": 0.0000008,
+"output_cost_per_token": 0.0000008,
+"litellm_provider": "together_ai"
},
"together-ai-40.1b-70b": {
-"input_cost_per_token": 0.000003,
-"output_cost_per_token": 0.000003
+"input_cost_per_token": 0.0000009,
+"output_cost_per_token": 0.0000009,
+"litellm_provider": "together_ai"
},
"ollama/llama2": {
"max_tokens": 4096,
@@ -578,10 +777,38 @@
"litellm_provider": "ollama",
"mode": "completion"
},
+"ollama/mistral": {
+"max_tokens": 8192,
+"input_cost_per_token": 0.0,
+"output_cost_per_token": 0.0,
+"litellm_provider": "ollama",
+"mode": "completion"
+},
+"ollama/codellama": {
+"max_tokens": 4096,
+"input_cost_per_token": 0.0,
+"output_cost_per_token": 0.0,
+"litellm_provider": "ollama",
+"mode": "completion"
+},
+"ollama/orca-mini": {
+"max_tokens": 4096,
+"input_cost_per_token": 0.0,
+"output_cost_per_token": 0.0,
+"litellm_provider": "ollama",
+"mode": "completion"
+},
+"ollama/vicuna": {
+"max_tokens": 2048,
+"input_cost_per_token": 0.0,
+"output_cost_per_token": 0.0,
+"litellm_provider": "ollama",
+"mode": "completion"
+},
"deepinfra/meta-llama/Llama-2-70b-chat-hf": {
-"max_tokens": 6144,
-"input_cost_per_token": 0.000001875,
-"output_cost_per_token": 0.000001875,
+"max_tokens": 4096,
+"input_cost_per_token": 0.000000700,
+"output_cost_per_token": 0.000000950,
"litellm_provider": "deepinfra",
"mode": "chat"
},
@ -619,5 +846,103 @@
|
||||||
"output_cost_per_token": 0.00000095,
|
"output_cost_per_token": 0.00000095,
|
||||||
"litellm_provider": "deepinfra",
|
"litellm_provider": "deepinfra",
|
||||||
"mode": "chat"
|
"mode": "chat"
|
||||||
|
},
|
||||||
|
"perplexity/pplx-7b-chat": {
|
||||||
|
"max_tokens": 8192,
|
||||||
|
"input_cost_per_token": 0.0000000,
|
||||||
|
"output_cost_per_token": 0.000000,
|
||||||
|
"litellm_provider": "perplexity",
|
||||||
|
"mode": "chat"
|
||||||
|
},
|
||||||
|
"perplexity/pplx-70b-chat": {
|
||||||
|
"max_tokens": 4096,
|
||||||
|
"input_cost_per_token": 0.0000000,
|
||||||
|
"output_cost_per_token": 0.000000,
|
||||||
|
"litellm_provider": "perplexity",
|
||||||
|
"mode": "chat"
|
||||||
|
},
|
||||||
|
"perplexity/pplx-7b-online": {
|
||||||
|
"max_tokens": 4096,
|
||||||
|
"input_cost_per_token": 0.0000000,
|
||||||
|
"output_cost_per_token": 0.0005,
|
||||||
|
"litellm_provider": "perplexity",
|
||||||
|
"mode": "chat"
|
||||||
|
},
|
||||||
|
"perplexity/pplx-70b-online": {
|
||||||
|
"max_tokens": 4096,
|
||||||
|
"input_cost_per_token": 0.0000000,
|
||||||
|
"output_cost_per_token": 0.0005,
|
||||||
|
"litellm_provider": "perplexity",
|
||||||
|
"mode": "chat"
|
||||||
|
},
|
||||||
|
"perplexity/llama-2-13b-chat": {
|
||||||
|
"max_tokens": 4096,
|
||||||
|
"input_cost_per_token": 0.0000000,
|
||||||
|
"output_cost_per_token": 0.000000,
|
||||||
|
"litellm_provider": "perplexity",
|
||||||
|
"mode": "chat"
|
||||||
|
},
|
||||||
|
"perplexity/llama-2-70b-chat": {
|
||||||
|
"max_tokens": 4096,
|
||||||
|
"input_cost_per_token": 0.0000000,
|
||||||
|
"output_cost_per_token": 0.000000,
|
||||||
|
"litellm_provider": "perplexity",
|
||||||
|
"mode": "chat"
|
||||||
|
},
|
||||||
|
"perplexity/mistral-7b-instruct": {
|
||||||
|
"max_tokens": 4096,
|
||||||
|
"input_cost_per_token": 0.0000000,
|
||||||
|
"output_cost_per_token": 0.000000,
|
||||||
|
"litellm_provider": "perplexity",
|
||||||
|
"mode": "chat"
|
||||||
|
},
|
||||||
|
"perplexity/replit-code-v1.5-3b": {
|
||||||
|
"max_tokens": 4096,
|
||||||
|
"input_cost_per_token": 0.0000000,
|
||||||
|
"output_cost_per_token": 0.000000,
|
||||||
|
"litellm_provider": "perplexity",
|
||||||
|
"mode": "chat"
|
||||||
|
},
|
||||||
|
"anyscale/mistralai/Mistral-7B-Instruct-v0.1": {
|
||||||
|
"max_tokens": 16384,
|
||||||
|
"input_cost_per_token": 0.00000015,
|
||||||
|
"output_cost_per_token": 0.00000015,
|
||||||
|
"litellm_provider": "anyscale",
|
||||||
|
"mode": "chat"
|
||||||
|
},
|
||||||
|
"anyscale/HuggingFaceH4/zephyr-7b-beta": {
|
||||||
|
"max_tokens": 16384,
|
||||||
|
"input_cost_per_token": 0.00000015,
|
||||||
|
"output_cost_per_token": 0.00000015,
|
||||||
|
"litellm_provider": "anyscale",
|
||||||
|
"mode": "chat"
|
||||||
|
},
|
||||||
|
"anyscale/meta-llama/Llama-2-7b-chat-hf": {
|
||||||
|
"max_tokens": 4096,
|
||||||
|
"input_cost_per_token": 0.00000015,
|
||||||
|
"output_cost_per_token": 0.00000015,
|
||||||
|
"litellm_provider": "anyscale",
|
||||||
|
"mode": "chat"
|
||||||
|
},
|
||||||
|
"anyscale/meta-llama/Llama-2-13b-chat-hf": {
|
||||||
|
"max_tokens": 4096,
|
||||||
|
"input_cost_per_token": 0.00000025,
|
||||||
|
"output_cost_per_token": 0.00000025,
|
||||||
|
"litellm_provider": "anyscale",
|
||||||
|
"mode": "chat"
|
||||||
|
},
|
||||||
|
"anyscale/meta-llama/Llama-2-70b-chat-hf": {
|
||||||
|
"max_tokens": 4096,
|
||||||
|
"input_cost_per_token": 0.000001,
|
||||||
|
"output_cost_per_token": 0.000001,
|
||||||
|
"litellm_provider": "anyscale",
|
||||||
|
"mode": "chat"
|
||||||
|
},
|
||||||
|
"anyscale/codellama/CodeLlama-34b-Instruct-hf": {
|
||||||
|
"max_tokens": 16384,
|
||||||
|
"input_cost_per_token": 0.000001,
|
||||||
|
"output_cost_per_token": 0.000001,
|
||||||
|
"litellm_provider": "anyscale",
|
||||||
|
"mode": "chat"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
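For orientation, the per-token prices above turn into a request cost by straight multiplication. A minimal sketch, with the entry copied from the pricing map above and made-up token counts:

# Cost = prompt_tokens * input_cost_per_token + completion_tokens * output_cost_per_token.
# Entry copied from the map above; the token counts are illustrative, not from a real response.
entry = {"input_cost_per_token": 0.0000008, "output_cost_per_token": 0.0000008}  # together-ai-20.1b-40b

prompt_tokens, completion_tokens = 1000, 200
cost = (prompt_tokens * entry["input_cost_per_token"]
        + completion_tokens * entry["output_cost_per_token"])
print(f"${cost:.6f}")  # $0.000960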
litellm/proxy/_experimental/post_call_rules.py (new file, 4 lines)
@@ -0,0 +1,4 @@
+def my_custom_rule(input):  # receives the model response
+    # if len(input) < 5:  # trigger fallback if the model response is too short
+    #     return False
+    return True
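As a sketch of how such a rule is used: the load_router_config change later in this diff wires a post_call_rules value from litellm_settings into litellm.post_call_rules, so the equivalent direct-Python registration would be the following (the uncommented length check is illustrative):

# Illustrative registration of a post-call rule; the proxy does the same thing
# from config via litellm_settings -> post_call_rules (see load_router_config later in this diff).
import litellm

def short_response_rule(input):  # receives the model response text
    if len(input) < 5:  # trigger fallback if the response is too short
        return False
    return True

litellm.post_call_rules = [short_response_rule]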
@@ -2,8 +2,21 @@ from pydantic import BaseModel, Extra, Field, root_validator
 from typing import Optional, List, Union, Dict, Literal
 from datetime import datetime
 import uuid, json
 
+
+class LiteLLMBase(BaseModel):
+    """
+    Implements default functions, all pydantic objects should have.
+    """
+    def json(self, **kwargs):
+        try:
+            return self.model_dump()  # noqa
+        except:
+            # if using pydantic v1
+            return self.dict()
+
+
 ######### Request Class Definition ######
-class ProxyChatCompletionRequest(BaseModel):
+class ProxyChatCompletionRequest(LiteLLMBase):
     model: str
     messages: List[Dict[str, str]]
     temperature: Optional[float] = None
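The try/except in LiteLLMBase.json() is there because pydantic v2 renamed the dict-dump API; a standalone sketch of the same pattern (illustrative model, not one of the proxy's classes):

# pydantic v2 exposes model_dump(); pydantic v1 only has .dict().
from pydantic import BaseModel

class KeyInfo(BaseModel):  # illustrative
    key: str
    spend: float = 0.0

def dump(obj: BaseModel) -> dict:
    try:
        return obj.model_dump()   # pydantic v2
    except AttributeError:
        return obj.dict()         # pydantic v1 fallback

print(dump(KeyInfo(key="sk-test")))  # {'key': 'sk-test', 'spend': 0.0}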
@@ -38,16 +51,16 @@ class ProxyChatCompletionRequest(BaseModel):
     class Config:
         extra='allow' # allow params not defined here, these fall in litellm.completion(**kwargs)
 
-class ModelInfoDelete(BaseModel):
+class ModelInfoDelete(LiteLLMBase):
     id: Optional[str]
 
 
-class ModelInfo(BaseModel):
+class ModelInfo(LiteLLMBase):
     id: Optional[str]
     mode: Optional[Literal['embedding', 'chat', 'completion']]
-    input_cost_per_token: Optional[float]
-    output_cost_per_token: Optional[float]
-    max_tokens: Optional[int]
+    input_cost_per_token: Optional[float] = 0.0
+    output_cost_per_token: Optional[float] = 0.0
+    max_tokens: Optional[int] = 2048 # assume 2048 if not set
 
     # for azure models we need users to specify the base model, one azure you can call deployments - azure/my-random-model
     # we look up the base model in model_prices_and_context_window.json
@@ -66,37 +79,40 @@ class ModelInfo(BaseModel):
         extra = Extra.allow # Allow extra fields
         protected_namespaces = ()
 
-    # @root_validator(pre=True)
-    # def set_model_info(cls, values):
-    #     if values.get("id") is None:
-    #         values.update({"id": str(uuid.uuid4())})
-    #     if values.get("mode") is None:
-    #         values.update({"mode": str(uuid.uuid4())})
-    #     return values
+    @root_validator(pre=True)
+    def set_model_info(cls, values):
+        if values.get("id") is None:
+            values.update({"id": str(uuid.uuid4())})
+        if values.get("mode") is None:
+            values.update({"mode": None})
+        if values.get("input_cost_per_token") is None:
+            values.update({"input_cost_per_token": None})
+        if values.get("output_cost_per_token") is None:
+            values.update({"output_cost_per_token": None})
+        if values.get("max_tokens") is None:
+            values.update({"max_tokens": None})
+        if values.get("base_model") is None:
+            values.update({"base_model": None})
+        return values
 
 
-class ModelParams(BaseModel):
+class ModelParams(LiteLLMBase):
     model_name: str
     litellm_params: dict
-    model_info: Optional[ModelInfo]=None
-
-    # def __init__(self, model_name: str, litellm_params: dict, model_info: Optional[ModelInfo] = None):
-    #     self.model_name = model_name
-    #     self.litellm_params = litellm_params
-    #     self.model_info = model_info if model_info else ModelInfo()
-    #     super.__init__(model_name=self.model_name, litellm_params=self.litellm_params, model_info=self.model_info)
+    model_info: ModelInfo
 
     class Config:
         protected_namespaces = ()
 
-    # @root_validator(pre=True)
-    # def set_model_info(cls, values):
-    #     if values.get("model_info") is None:
-    #         values.update({"model_info": ModelInfo()})
-    #     return values
+    @root_validator(pre=True)
+    def set_model_info(cls, values):
+        if values.get("model_info") is None:
+            values.update({"model_info": ModelInfo()})
+        return values
 
-class GenerateKeyRequest(BaseModel):
+class GenerateKeyRequest(LiteLLMBase):
     duration: Optional[str] = "1h"
     models: Optional[list] = []
     aliases: Optional[dict] = {}
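The new set_model_info validators follow one pattern: run before field validation and fill every missing value with an explicit default (a generated id, None, or a default ModelInfo()), so downstream lookups can assume the keys exist. A standalone sketch of that pattern (illustrative class, not the litellm one):

import uuid
from typing import Optional
from pydantic import BaseModel, root_validator

class InfoSketch(BaseModel):   # illustrative
    id: Optional[str]
    mode: Optional[str]
    base_model: Optional[str]

    @root_validator(pre=True)
    def set_defaults(cls, values):
        if values.get("id") is None:
            values.update({"id": str(uuid.uuid4())})
        for field in ("mode", "base_model"):
            if values.get(field) is None:
                values.update({field: None})
        return values

print(InfoSketch().id)    # auto-generated uuid4, even though no id was passed
print(InfoSketch().mode)  # None, but the key is guaranteed to exist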
@@ -105,26 +121,32 @@ class GenerateKeyRequest(BaseModel):
     user_id: Optional[str] = None
     max_parallel_requests: Optional[int] = None
 
-    def json(self, **kwargs):
-        try:
-            return self.model_dump()  # noqa
-        except:
-            # if using pydantic v1
-            return self.dict()
+class UpdateKeyRequest(LiteLLMBase):
+    key: str
+    duration: Optional[str] = None
+    models: Optional[list] = None
+    aliases: Optional[dict] = None
+    config: Optional[dict] = None
+    spend: Optional[float] = None
+    user_id: Optional[str] = None
+    max_parallel_requests: Optional[int] = None
 
-class GenerateKeyResponse(BaseModel):
+class GenerateKeyResponse(LiteLLMBase):
     key: str
     expires: datetime
     user_id: str
 
-class _DeleteKeyObject(BaseModel):
+class _DeleteKeyObject(LiteLLMBase):
     key: str
 
-class DeleteKeyRequest(BaseModel):
+class DeleteKeyRequest(LiteLLMBase):
     keys: List[_DeleteKeyObject]
 
-class UserAPIKeyAuth(BaseModel): # the expected response object for user api key auth
+class UserAPIKeyAuth(LiteLLMBase): # the expected response object for user api key auth
     """
     Return the row in the db
     """
@@ -137,7 +159,7 @@ class UserAPIKeyAuth(BaseModel): # the expected response object for user api key
     max_parallel_requests: Optional[int] = None
     duration: str = "1h"
 
-class ConfigGeneralSettings(BaseModel):
+class ConfigGeneralSettings(LiteLLMBase):
     """
     Documents all the fields supported by `general_settings` in config.yaml
     """
@@ -153,10 +175,12 @@ class ConfigGeneralSettings(BaseModel):
     health_check_interval: int = Field(300, description="background health check interval in seconds")
 
 
-class ConfigYAML(BaseModel):
+class ConfigYAML(LiteLLMBase):
     """
     Documents all the fields supported by the config.yaml
     """
     model_list: Optional[List[ModelParams]] = Field(None, description="List of supported models on the server, with model-specific configs")
     litellm_settings: Optional[dict] = Field(None, description="litellm Module settings. See __init__.py for all, example litellm.drop_params=True, litellm.set_verbose=True, litellm.api_base, litellm.cache")
     general_settings: Optional[ConfigGeneralSettings] = None
+
+    class Config:
+        protected_namespaces = ()
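UpdateKeyRequest backs the new /key/update route added at the end of this diff. A hypothetical call against a locally running proxy; the base URL, master key, and generated key below are placeholders:

import requests

resp = requests.post(
    "http://0.0.0.0:8000/key/update",              # placeholder proxy address
    headers={"Authorization": "Bearer sk-1234"},   # master key (placeholder)
    json={
        "key": "sk-generated-key",                 # key being updated (placeholder)
        "models": ["gpt-3.5-turbo"],
        "max_parallel_requests": 10,
    },
)
print(resp.status_code, resp.json())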
@@ -1,3 +1,11 @@
+import sys, os, traceback
+
+# this file is to test litellm/proxy
+
+sys.path.insert(
+    0, os.path.abspath("../..")
+)  # Adds the parent directory to the system path
+
 from litellm.integrations.custom_logger import CustomLogger
 import litellm
 import inspect
@@ -37,8 +45,11 @@ class MyCustomHandler(CustomLogger):
     def log_success_event(self, kwargs, response_obj, start_time, end_time):
         print_verbose("On Success!")
 
+
     async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
         print_verbose(f"On Async Success!")
+        response_cost = litellm.completion_cost(completion_response=response_obj)
+        assert response_cost > 0.0
         return
 
     async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
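Outside of this test file, a handler like the one above is attached by assigning it to litellm.callbacks (the proxy config does the equivalent with callbacks: custom_callbacks.proxy_handler_instance). A minimal sketch:

import litellm
from litellm.integrations.custom_logger import CustomLogger

class CostLoggingHandler(CustomLogger):
    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        # same cost lookup the test above asserts on
        response_cost = litellm.completion_cost(completion_response=response_obj)
        print(f"response cost: {response_cost}")

litellm.callbacks = [CostLoggingHandler()]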
@@ -69,7 +69,6 @@ async def _perform_health_check(model_list: list):
     for model in model_list:
         litellm_params = model["litellm_params"]
         model_info = model.get("model_info", {})
-        litellm_params["model"] = litellm.utils.remove_model_id(litellm_params["model"])
         litellm_params["messages"] = _get_random_llm_message()
 
         prepped_params.append(litellm_params)
@@ -1,6 +1,7 @@
 from typing import Optional
 import litellm
 from litellm.caching import DualCache
+from litellm.proxy._types import UserAPIKeyAuth
 from litellm.integrations.custom_logger import CustomLogger
 from fastapi import HTTPException
 
@@ -14,24 +15,28 @@ class MaxParallelRequestsHandler(CustomLogger):
         print(print_statement) # noqa
 
-    async def max_parallel_request_allow_request(self, max_parallel_requests: Optional[int], api_key: Optional[str], user_api_key_cache: DualCache):
+    async def async_pre_call_hook(self, user_api_key_dict: UserAPIKeyAuth, cache: DualCache, data: dict, call_type: str):
+        self.print_verbose(f"Inside Max Parallel Request Pre-Call Hook")
+        api_key = user_api_key_dict.api_key
+        max_parallel_requests = user_api_key_dict.max_parallel_requests
+
         if api_key is None:
             return
 
         if max_parallel_requests is None:
             return
 
-        self.user_api_key_cache = user_api_key_cache # save the api key cache for updating the value
+        self.user_api_key_cache = cache # save the api key cache for updating the value
 
         # CHECK IF REQUEST ALLOWED
         request_count_api_key = f"{api_key}_request_count"
-        current = user_api_key_cache.get_cache(key=request_count_api_key)
+        current = cache.get_cache(key=request_count_api_key)
         self.print_verbose(f"current: {current}")
         if current is None:
-            user_api_key_cache.set_cache(request_count_api_key, 1)
+            cache.set_cache(request_count_api_key, 1)
         elif int(current) < max_parallel_requests:
             # Increase count for this token
-            user_api_key_cache.set_cache(request_count_api_key, int(current) + 1)
+            cache.set_cache(request_count_api_key, int(current) + 1)
         else:
             raise HTTPException(status_code=429, detail="Max parallel request limit reached.")
 
@@ -55,16 +60,24 @@ class MaxParallelRequestsHandler(CustomLogger):
         except Exception as e:
             self.print_verbose(e) # noqa
 
-    async def async_log_failure_call(self, api_key, user_api_key_cache):
+    async def async_log_failure_call(self, user_api_key_dict: UserAPIKeyAuth, original_exception: Exception):
         try:
+            self.print_verbose(f"Inside Max Parallel Request Failure Hook")
+            api_key = user_api_key_dict.api_key
             if api_key is None:
                 return
 
-            request_count_api_key = f"{api_key}_request_count"
-            # Decrease count for this token
-            current = self.user_api_key_cache.get_cache(key=request_count_api_key) or 1
-            new_val = current - 1
-            self.print_verbose(f"updated_value in failure call: {new_val}")
-            self.user_api_key_cache.set_cache(request_count_api_key, new_val)
+            ## decrement call count if call failed
+            if (hasattr(original_exception, "status_code")
+                and original_exception.status_code == 429
+                and "Max parallel request limit reached" in str(original_exception)):
+                pass # ignore failed calls due to max limit being reached
+            else:
+                request_count_api_key = f"{api_key}_request_count"
+                # Decrease count for this token
+                current = self.user_api_key_cache.get_cache(key=request_count_api_key) or 1
+                new_val = current - 1
+                self.print_verbose(f"updated_value in failure call: {new_val}")
+                self.user_api_key_cache.set_cache(request_count_api_key, new_val)
         except Exception as e:
             self.print_verbose(f"An exception occurred - {str(e)}") # noqa
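The pre-call hook above reads max_parallel_requests off the authenticated key, which is set when the key is generated. A hypothetical /key/generate call (placeholder address and master key):

import requests

resp = requests.post(
    "http://0.0.0.0:8000/key/generate",            # placeholder proxy address
    headers={"Authorization": "Bearer sk-1234"},   # master key (placeholder)
    json={"duration": "1h", "models": [], "max_parallel_requests": 5},
)
print(resp.json())  # expected to include "key", "expires", "user_id"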
@@ -3,6 +3,7 @@ import subprocess, traceback, json
 import os, sys
 import random, appdirs
 from datetime import datetime
+import importlib
 from dotenv import load_dotenv
 import operator
 sys.path.append(os.getcwd())
@@ -76,13 +77,14 @@ def is_port_in_use(port):
 @click.option('--config', '-c', default=None, help='Path to the proxy configuration file (e.g. config.yaml). Usage `litellm --config config.yaml`')
 @click.option('--max_budget', default=None, type=float, help='Set max budget for API calls - works for hosted models like OpenAI, TogetherAI, Anthropic, etc.`')
 @click.option('--telemetry', default=True, type=bool, help='Helps us know if people are using this feature. Turn this off by doing `--telemetry False`')
+@click.option('--version', '-v', default=False, is_flag=True, type=bool, help='Print LiteLLM version')
 @click.option('--logs', flag_value=False, type=int, help='Gets the "n" most recent logs. By default gets most recent log.')
 @click.option('--health', flag_value=True, help='Make a chat/completions request to all llms in config.yaml')
 @click.option('--test', flag_value=True, help='proxy chat completions url to make a test request to')
 @click.option('--test_async', default=False, is_flag=True, help='Calls async endpoints /queue/requests and /queue/response')
 @click.option('--num_requests', default=10, type=int, help='Number of requests to hit async endpoint with')
 @click.option('--local', is_flag=True, default=False, help='for local debugging')
-def run_server(host, port, api_base, api_version, model, alias, add_key, headers, save, debug, temperature, max_tokens, request_timeout, drop_params, add_function_to_prompt, config, max_budget, telemetry, logs, test, local, num_workers, test_async, num_requests, use_queue, health):
+def run_server(host, port, api_base, api_version, model, alias, add_key, headers, save, debug, temperature, max_tokens, request_timeout, drop_params, add_function_to_prompt, config, max_budget, telemetry, logs, test, local, num_workers, test_async, num_requests, use_queue, health, version):
     global feature_telemetry
     args = locals()
     if local:
@@ -113,6 +115,10 @@ def run_server(host, port, api_base, api_version, model, alias, add_key, headers
         except:
             raise Exception("LiteLLM: No logs saved!")
         return
+    if version == True:
+        pkg_version = importlib.metadata.version("litellm")
+        click.echo(f'\nLiteLLM: Current Version = {pkg_version}\n')
+        return
     if model and "ollama" in model and api_base is None:
         run_ollama_serve()
     if test_async is True:
@@ -11,8 +11,10 @@ model_list:
       output_cost_per_token: 0.00003
       max_tokens: 4096
       base_model: gpt-3.5-turbo
-  - model_name: openai-gpt-3.5
+  - model_name: BEDROCK_GROUP
+    litellm_params:
+      model: bedrock/cohere.command-text-v14
+  - model_name: Azure OpenAI GPT-4 Canada-East (External)
     litellm_params:
       model: gpt-3.5-turbo
       api_key: os.environ/OPENAI_API_KEY
@@ -41,11 +43,12 @@ model_list:
       mode: completion
 
 litellm_settings:
+  # cache: True
   # setting callback class
   # callbacks: custom_callbacks.proxy_handler_instance # sets litellm.callbacks = [proxy_handler_instance]
-  model_group_alias_map: {"gpt-4": "openai-gpt-3.5"} # all requests with gpt-4 model_name, get sent to openai-gpt-3.5
 
 general_settings:
 
+environment_variables:
   # otel: True # OpenTelemetry Logger
   # master_key: sk-1234 # [OPTIONAL] Only use this if you to require all calls to contain this key (Authorization: Bearer sk-1234)
@@ -195,8 +195,10 @@ prisma_client: Optional[PrismaClient] = None
 user_api_key_cache = DualCache()
 user_custom_auth = None
 use_background_health_checks = None
+use_queue = False
 health_check_interval = None
 health_check_results = {}
+queue: List = []
 ### INITIALIZE GLOBAL LOGGING OBJECT ###
 proxy_logging_obj = ProxyLogging(user_api_key_cache=user_api_key_cache)
 ### REDIS QUEUE ###
@@ -252,51 +254,58 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap
         if api_key is None: # only require api key if master key is set
             raise Exception(f"No api key passed in.")
 
-        route = request.url.path
+        route: str = request.url.path
 
         # note: never string compare api keys, this is vulenerable to a time attack. Use secrets.compare_digest instead
         is_master_key_valid = secrets.compare_digest(api_key, master_key)
         if is_master_key_valid:
             return UserAPIKeyAuth(api_key=master_key)
 
-        if (route == "/key/generate" or route == "/key/delete" or route == "/key/info") and not is_master_key_valid:
-            raise Exception(f"If master key is set, only master key can be used to generate, delete or get info for new keys")
+        if route.startswith("/key/") and not is_master_key_valid:
+            raise Exception(f"If master key is set, only master key can be used to generate, delete, update or get info for new keys")
 
-        if prisma_client:
-            ## check for cache hit (In-Memory Cache)
-            valid_token = user_api_key_cache.get_cache(key=api_key)
-            print(f"valid_token from cache: {valid_token}")
-            if valid_token is None:
-                ## check db
-                valid_token = await prisma_client.get_data(token=api_key, expires=datetime.utcnow())
-                user_api_key_cache.set_cache(key=api_key, value=valid_token, ttl=60)
-            elif valid_token is not None:
-                print(f"API Key Cache Hit!")
-            if valid_token:
-                litellm.model_alias_map = valid_token.aliases
-                config = valid_token.config
-                if config != {}:
-                    model_list = config.get("model_list", [])
-                    llm_model_list = model_list
-                    print("\n new llm router model list", llm_model_list)
-                if len(valid_token.models) == 0: # assume an empty model list means all models are allowed to be called
-                    api_key = valid_token.token
-                    valid_token_dict = _get_pydantic_json_dict(valid_token)
-                    valid_token_dict.pop("token", None)
-                    return UserAPIKeyAuth(api_key=api_key, **valid_token_dict)
-                else:
-                    data = await request.json()
-                    model = data.get("model", None)
-                    if model in litellm.model_alias_map:
-                        model = litellm.model_alias_map[model]
-                    if model and model not in valid_token.models:
-                        raise Exception(f"Token not allowed to access model")
+        if prisma_client is None: # if both master key + user key submitted, and user key != master key, and no db connected, raise an error
+            raise Exception("No connected db.")
+
+        ## check for cache hit (In-Memory Cache)
+        valid_token = user_api_key_cache.get_cache(key=api_key)
+        print(f"valid_token from cache: {valid_token}")
+        if valid_token is None:
+            ## check db
+            print(f"api key: {api_key}")
+            valid_token = await prisma_client.get_data(token=api_key, expires=datetime.utcnow())
+            print(f"valid token from prisma: {valid_token}")
+            user_api_key_cache.set_cache(key=api_key, value=valid_token, ttl=60)
+        elif valid_token is not None:
+            print(f"API Key Cache Hit!")
+        if valid_token:
+            litellm.model_alias_map = valid_token.aliases
+            config = valid_token.config
+            if config != {}:
+                model_list = config.get("model_list", [])
+                llm_model_list = model_list
+                print("\n new llm router model list", llm_model_list)
+            if len(valid_token.models) == 0: # assume an empty model list means all models are allowed to be called
                 api_key = valid_token.token
                 valid_token_dict = _get_pydantic_json_dict(valid_token)
                 valid_token_dict.pop("token", None)
                 return UserAPIKeyAuth(api_key=api_key, **valid_token_dict)
             else:
-                raise Exception(f"Invalid token")
+                try:
+                    data = await request.json()
+                except json.JSONDecodeError:
+                    data = {} # Provide a default value, such as an empty dictionary
+                model = data.get("model", None)
+                if model in litellm.model_alias_map:
+                    model = litellm.model_alias_map[model]
+                if model and model not in valid_token.models:
+                    raise Exception(f"Token not allowed to access model")
+                api_key = valid_token.token
+                valid_token_dict = _get_pydantic_json_dict(valid_token)
+                valid_token_dict.pop("token", None)
+                return UserAPIKeyAuth(api_key=api_key, **valid_token_dict)
+        else:
+            raise Exception(f"Invalid token")
     except Exception as e:
         print(f"An exception occurred - {traceback.format_exc()}")
         if isinstance(e, HTTPException):
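The comment in the auth code above explains the choice of secrets.compare_digest: it compares the two strings in constant time, so response timing does not leak how much of a guessed key matched. A minimal sketch:

import secrets

master_key = "sk-1234"        # placeholder
incoming_api_key = "sk-1234"

if secrets.compare_digest(incoming_api_key, master_key):
    print("master key accepted")
else:
    print("not the master key")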
@@ -310,24 +319,12 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap
 def prisma_setup(database_url: Optional[str]):
     global prisma_client, proxy_logging_obj, user_api_key_cache
 
-    proxy_logging_obj._init_litellm_callbacks()
     if database_url is not None:
         try:
             prisma_client = PrismaClient(database_url=database_url, proxy_logging_obj=proxy_logging_obj)
         except Exception as e:
             print("Error when initializing prisma, Ensure you run pip install prisma", e)
 
-def celery_setup(use_queue: bool):
-    global celery_fn, celery_app_conn, async_result
-    if use_queue:
-        from litellm.proxy.queue.celery_worker import start_worker
-        from litellm.proxy.queue.celery_app import celery_app, process_job
-        from celery.result import AsyncResult
-        start_worker(os.getcwd())
-        celery_fn = process_job
-        async_result = AsyncResult
-        celery_app_conn = celery_app
-
 def load_from_azure_key_vault(use_azure_key_vault: bool = False):
     if use_azure_key_vault is False:
         return
@@ -380,30 +377,14 @@ async def track_cost_callback(
         if "complete_streaming_response" in kwargs:
             # for tracking streaming cost we pass the "messages" and the output_text to litellm.completion_cost
             completion_response = kwargs["complete_streaming_response"]
-            input_text = kwargs["messages"]
-            output_text = completion_response["choices"][0]["message"]["content"]
-            response_cost = litellm.completion_cost(
-                model = kwargs["model"],
-                messages = input_text,
-                completion = output_text
-            )
+            response_cost = litellm.completion_cost(completion_response=completion_response)
             print("streaming response_cost", response_cost)
             user_api_key = kwargs["litellm_params"]["metadata"].get("user_api_key", None)
-            print(f"user_api_key - {user_api_key}; prisma_client - {prisma_client}")
             if user_api_key and prisma_client:
                 await update_prisma_database(token=user_api_key, response_cost=response_cost)
         elif kwargs["stream"] == False: # for non streaming responses
-            input_text = kwargs.get("messages", "")
-            print(f"type of input_text: {type(input_text)}")
-            if isinstance(input_text, list):
-                response_cost = litellm.completion_cost(completion_response=completion_response, messages=input_text)
-            elif isinstance(input_text, str):
-                response_cost = litellm.completion_cost(completion_response=completion_response, prompt=input_text)
-            print(f"received completion response: {completion_response}")
-
-            print(f"regular response_cost: {response_cost}")
+            response_cost = litellm.completion_cost(completion_response=completion_response)
             user_api_key = kwargs["litellm_params"]["metadata"].get("user_api_key", None)
-            print(f"user_api_key - {user_api_key}; prisma_client - {prisma_client}")
             if user_api_key and prisma_client:
                 await update_prisma_database(token=user_api_key, response_cost=response_cost)
     except Exception as e:
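The simplified tracker above hands the whole response object (for streams, the re-assembled complete_streaming_response) to litellm.completion_cost, which looks prices up from model_prices_and_context_window.json. A minimal sketch of that call shape, assuming OPENAI_API_KEY is set in the environment:

import litellm

response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hi"}],
)
response_cost = litellm.completion_cost(completion_response=response)
print(f"response_cost: {response_cost}")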
@@ -459,7 +440,7 @@ async def _run_background_health_check():
         await asyncio.sleep(health_check_interval)
 
 def load_router_config(router: Optional[litellm.Router], config_file_path: str):
-    global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, use_background_health_checks, health_check_interval
+    global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, use_background_health_checks, health_check_interval, use_queue
     config = {}
     try:
         if os.path.exists(config_file_path):
@@ -504,6 +485,18 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
                 cache_port = litellm.get_secret("REDIS_PORT", None)
                 cache_password = litellm.get_secret("REDIS_PASSWORD", None)
 
+                cache_params = {
+                    "type": cache_type,
+                    "host": cache_host,
+                    "port": cache_port,
+                    "password": cache_password
+                }
+
+                if "cache_params" in litellm_settings:
+                    cache_params_in_config = litellm_settings["cache_params"]
+                    # overwrite cache_params with cache_params_in_config
+                    cache_params.update(cache_params_in_config)
+
                 # Assuming cache_type, cache_host, cache_port, and cache_password are strings
                 print(f"{blue_color_code}Cache Type:{reset_color_code} {cache_type}")
                 print(f"{blue_color_code}Cache Host:{reset_color_code} {cache_host}")
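The cache block above builds defaults from the REDIS_* secrets and then lets litellm_settings.cache_params from the config override them with a plain dict update; a sketch of that merge (values are illustrative):

cache_params = {
    "type": "redis",
    "host": "localhost",   # e.g. from REDIS_HOST
    "port": "6379",        # e.g. from REDIS_PORT
    "password": None,      # e.g. from REDIS_PASSWORD
}
cache_params_in_config = {"password": "my-redis-password"}  # from config.yaml (illustrative)
cache_params.update(cache_params_in_config)
print(cache_params)
# {'type': 'redis', 'host': 'localhost', 'port': '6379', 'password': 'my-redis-password'}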
@@ -513,15 +506,15 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
 
                 ## to pass a complete url, or set ssl=True, etc. just set it as `os.environ[REDIS_URL] = <your-redis-url>`, _redis.py checks for REDIS specific environment variables
                 litellm.cache = Cache(
-                    type=cache_type,
-                    host=cache_host,
-                    port=cache_port,
-                    password=cache_password
+                    **cache_params
                 )
                 print(f"{blue_color_code}Set Cache on LiteLLM Proxy: {litellm.cache.cache}{reset_color_code} {cache_password}")
             elif key == "callbacks":
                 litellm.callbacks = [get_instance_fn(value=value, config_file_path=config_file_path)]
                 print_verbose(f"{blue_color_code} Initialized Callbacks - {litellm.callbacks} {reset_color_code}")
+            elif key == "post_call_rules":
+                litellm.post_call_rules = [get_instance_fn(value=value, config_file_path=config_file_path)]
+                print(f"litellm.post_call_rules: {litellm.post_call_rules}")
             elif key == "success_callback":
                 litellm.success_callback = []
 
@@ -533,10 +526,6 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
                     # these are litellm callbacks - "langfuse", "sentry", "wandb"
                     else:
                         litellm.success_callback.append(callback)
-                        if callback == "traceloop":
-                            from traceloop.sdk import Traceloop
-                            print_verbose(f"{blue_color_code} Initializing Traceloop SDK - \nRunning:`Traceloop.init(app_name='Litellm-Server', disable_batch=True)`")
-                            Traceloop.init(app_name="Litellm-Server", disable_batch=True)
                 print_verbose(f"{blue_color_code} Initialized Success Callbacks - {litellm.success_callback} {reset_color_code}")
             elif key == "failure_callback":
                 litellm.failure_callback = []
@@ -550,6 +539,10 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
                     else:
                         litellm.failure_callback.append(callback)
                 print_verbose(f"{blue_color_code} Initialized Success Callbacks - {litellm.failure_callback} {reset_color_code}")
+            elif key == "cache_params":
+                # this is set in the cache branch
+                # see usage here: https://docs.litellm.ai/docs/proxy/caching
+                pass
             else:
                 setattr(litellm, key, value)
 
@@ -572,7 +565,6 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
         cost_tracking()
         ### START REDIS QUEUE ###
         use_queue = general_settings.get("use_queue", False)
-        celery_setup(use_queue=use_queue)
         ### MASTER KEY ###
         master_key = general_settings.get("master_key", None)
         if master_key and master_key.startswith("os.environ/"):
@@ -683,6 +675,8 @@ async def generate_key_helper_fn(duration: Optional[str], models: list, aliases:
         raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR)
     return {"token": token, "expires": new_verification_token.expires, "user_id": user_id}
 
+
+
 async def delete_verification_token(tokens: List):
     global prisma_client
     try:
@@ -761,8 +755,6 @@ def initialize(
     if max_budget: # litellm-specific param
         litellm.max_budget = max_budget
         dynamic_config["general"]["max_budget"] = max_budget
-    if use_queue:
-        celery_setup(use_queue=use_queue)
     if experimental:
         pass
     user_telemetry = telemetry
@@ -798,48 +790,12 @@ def data_generator(response):
 async def async_data_generator(response, user_api_key_dict):
     print_verbose("inside generator")
     async for chunk in response:
-        # try:
-        #     await proxy_logging_obj.pre_call_hook(user_api_key_dict=user_api_key_dict, data=None, call_type="completion")
-        # except Exception as e:
-        #     print(f"An exception occurred - {str(e)}")
-
         print_verbose(f"returned chunk: {chunk}")
         try:
             yield f"data: {json.dumps(chunk.dict())}\n\n"
         except:
             yield f"data: {json.dumps(chunk)}\n\n"
 
-def litellm_completion(*args, **kwargs):
-    global user_temperature, user_request_timeout, user_max_tokens, user_api_base
-    call_type = kwargs.pop("call_type")
-    # override with user settings, these are params passed via cli
-    if user_temperature:
-        kwargs["temperature"] = user_temperature
-    if user_request_timeout:
-        kwargs["request_timeout"] = user_request_timeout
-    if user_max_tokens:
-        kwargs["max_tokens"] = user_max_tokens
-    if user_api_base:
-        kwargs["api_base"] = user_api_base
-    ## ROUTE TO CORRECT ENDPOINT ##
-    router_model_names = [m["model_name"] for m in llm_model_list] if llm_model_list is not None else []
-    try:
-        if llm_router is not None and kwargs["model"] in router_model_names: # model in router model list
-            if call_type == "chat_completion":
-                response = llm_router.completion(*args, **kwargs)
-            elif call_type == "text_completion":
-                response = llm_router.text_completion(*args, **kwargs)
-        else:
-            if call_type == "chat_completion":
-                response = litellm.completion(*args, **kwargs)
-            elif call_type == "text_completion":
-                response = litellm.text_completion(*args, **kwargs)
-    except Exception as e:
-        raise e
-    if 'stream' in kwargs and kwargs['stream'] == True: # use generate_responses to stream responses
-        return StreamingResponse(data_generator(response), media_type='text/event-stream')
-    return response
-
 def get_litellm_model_info(model: dict = {}):
     model_info = model.get("model_info", {})
     model_to_lookup = model.get("litellm_params", {}).get("model", None)
@@ -870,6 +826,8 @@ async def startup_event():
     initialize(**worker_config)
 
 
+    proxy_logging_obj._init_litellm_callbacks() # INITIALIZE LITELLM CALLBACKS ON SERVER STARTUP <- do this to catch any logging errors on startup, not when calls are being made
+
     if use_background_health_checks:
         asyncio.create_task(_run_background_health_check()) # start the background health check coroutine.
 
@@ -881,16 +839,6 @@ async def startup_event():
         # add master key to db
         await generate_key_helper_fn(duration=None, models=[], aliases={}, config={}, spend=0, token=master_key)
 
-@router.on_event("shutdown")
-async def shutdown_event():
-    global prisma_client, master_key, user_custom_auth
-    if prisma_client:
-        print("Disconnecting from Prisma")
-        await prisma_client.disconnect()
-
-    ## RESET CUSTOM VARIABLES ##
-    master_key = None
-    user_custom_auth = None
-
 #### API ENDPOINTS ####
 @router.get("/v1/models", dependencies=[Depends(user_api_key_auth)])
@@ -929,7 +877,8 @@ def model_list():
 @router.post("/v1/completions", dependencies=[Depends(user_api_key_auth)])
 @router.post("/completions", dependencies=[Depends(user_api_key_auth)])
 @router.post("/engines/{model:path}/completions", dependencies=[Depends(user_api_key_auth)])
-async def completion(request: Request, model: Optional[str] = None, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth)):
+async def completion(request: Request, model: Optional[str] = None, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), background_tasks: BackgroundTasks = BackgroundTasks()):
+    global user_temperature, user_request_timeout, user_max_tokens, user_api_base
     try:
         body = await request.body()
         body_str = body.decode()
@@ -938,7 +887,7 @@ async def completion(request: Request, model: Optional[str] = None, user_api_key
         except:
             data = json.loads(body_str)
 
-        data["user"] = user_api_key_dict.user_id
+        data["user"] = data.get("user", user_api_key_dict.user_id)
         data["model"] = (
             general_settings.get("completion_model", None) # server default
             or user_model # model name passed via cli args
@@ -947,17 +896,44 @@ async def completion(request: Request, model: Optional[str] = None, user_api_key
         )
         if user_model:
             data["model"] = user_model
-        data["call_type"] = "text_completion"
         if "metadata" in data:
             data["metadata"]["user_api_key"] = user_api_key_dict.api_key
         else:
             data["metadata"] = {"user_api_key": user_api_key_dict.api_key}
 
-        return litellm_completion(
-            **data
-        )
+        # override with user settings, these are params passed via cli
+        if user_temperature:
+            data["temperature"] = user_temperature
+        if user_request_timeout:
+            data["request_timeout"] = user_request_timeout
+        if user_max_tokens:
+            data["max_tokens"] = user_max_tokens
+        if user_api_base:
+            data["api_base"] = user_api_base
+
+        ### CALL HOOKS ### - modify incoming data before calling the model
+        data = await proxy_logging_obj.pre_call_hook(user_api_key_dict=user_api_key_dict, data=data, call_type="completion")
+
+        ### ROUTE THE REQUEST ###
+        router_model_names = [m["model_name"] for m in llm_model_list] if llm_model_list is not None else []
+        if llm_router is not None and data["model"] in router_model_names: # model in router model list
+            response = await llm_router.atext_completion(**data)
+        elif llm_router is not None and data["model"] in llm_router.deployment_names: # model in router deployments, calling a specific deployment on the router
+            response = await llm_router.atext_completion(**data, specific_deployment = True)
+        elif llm_router is not None and llm_router.model_group_alias is not None and data["model"] in llm_router.model_group_alias: # model set in model_group_alias
+            response = await llm_router.atext_completion(**data)
+        else: # router is not set
+            response = await litellm.atext_completion(**data)
+
+        print(f"final response: {response}")
+        if 'stream' in data and data['stream'] == True: # use generate_responses to stream responses
+            return StreamingResponse(async_data_generator(user_api_key_dict=user_api_key_dict, response=response), media_type='text/event-stream')
+
+        background_tasks.add_task(log_input_output, request, response) # background task for logging to OTEL
+        return response
     except Exception as e:
         print(f"\033[1;31mAn error occurred: {e}\n\n Debug this by setting `--debug`, e.g. `litellm --model gpt-3.5-turbo --debug`")
+        traceback.print_exc()
         error_traceback = traceback.format_exc()
         error_msg = f"{str(e)}\n\n{error_traceback}"
         try:
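For reference, the rewritten /completions route stays OpenAI-compatible, so a caller only needs the proxy address and a key. A hypothetical request (placeholder address, key, and model name):

import requests

resp = requests.post(
    "http://0.0.0.0:8000/completions",             # placeholder proxy address
    headers={"Authorization": "Bearer sk-1234"},   # placeholder key
    json={
        "model": "gpt-3.5-turbo",                  # any model_name from the proxy's model_list
        "prompt": "Say this is a test",
        "max_tokens": 10,
    },
)
print(resp.json())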
@@ -995,7 +971,7 @@ async def chat_completion(request: Request, model: Optional[str] = None, user_ap
         )
 
         # users can pass in 'user' param to /chat/completions. Don't override it
-        if data.get("user", None) is None:
+        if data.get("user", None) is None and user_api_key_dict.user_id is not None:
             # if users are using user_api_key_auth, set `user` in `data`
             data["user"] = user_api_key_dict.user_id
 
@@ -1027,7 +1003,7 @@ async def chat_completion(request: Request, model: Optional[str] = None, user_ap
             response = await llm_router.acompletion(**data)
         elif llm_router is not None and data["model"] in llm_router.deployment_names: # model in router deployments, calling a specific deployment on the router
             response = await llm_router.acompletion(**data, specific_deployment = True)
-        elif llm_router is not None and litellm.model_group_alias_map is not None and data["model"] in litellm.model_group_alias_map: # model set in model_group_alias_map
+        elif llm_router is not None and llm_router.model_group_alias is not None and data["model"] in llm_router.model_group_alias: # model set in model_group_alias
             response = await llm_router.acompletion(**data)
         else: # router is not set
             response = await litellm.acompletion(**data)
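The alias lookup now lives on the router (llm_router.model_group_alias) rather than the module-level litellm.model_group_alias_map. A sketch of configuring it, assuming litellm.Router accepts a model_group_alias mapping (the parameter name is taken from the attribute used above):

from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "openai-gpt-3.5",
            "litellm_params": {"model": "gpt-3.5-turbo", "api_key": "os.environ/OPENAI_API_KEY"},
        }
    ],
    model_group_alias={"gpt-4": "openai-gpt-3.5"},  # requests for "gpt-4" route to the openai-gpt-3.5 group
)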
@@ -1088,7 +1064,9 @@ async def embeddings(request: Request, user_api_key_dict: UserAPIKeyAuth = Depen
             "body": copy.copy(data) # use copy instead of deepcopy
         }
 
-        data["user"] = user_api_key_dict.user_id
+        if data.get("user", None) is None and user_api_key_dict.user_id is not None:
+            data["user"] = user_api_key_dict.user_id
+
         data["model"] = (
             general_settings.get("embedding_model", None) # server default
             or user_model # model name passed via cli args
@@ -1098,10 +1076,11 @@ async def embeddings(request: Request, user_api_key_dict: UserAPIKeyAuth = Depen
             data["model"] = user_model
         if "metadata" in data:
             data["metadata"]["user_api_key"] = user_api_key_dict.api_key
+            data["metadata"]["headers"] = dict(request.headers)
         else:
             data["metadata"] = {"user_api_key": user_api_key_dict.api_key}
+            data["metadata"]["headers"] = dict(request.headers)
         router_model_names = [m["model_name"] for m in llm_model_list] if llm_model_list is not None else []
-        print(f"received data: {data['input']}")
         if "input" in data and isinstance(data['input'], list) and isinstance(data['input'][0], list) and isinstance(data['input'][0][0], int): # check if array of tokens passed in
             # check if non-openai/azure model called - e.g. for langchain integration
             if llm_model_list is not None and data["model"] in router_model_names:
@@ -1119,12 +1098,13 @@ async def embeddings(request: Request, user_api_key_dict: UserAPIKeyAuth = Depen
 
         ### CALL HOOKS ### - modify incoming data / reject request before calling the model
         data = await proxy_logging_obj.pre_call_hook(user_api_key_dict=user_api_key_dict, data=data, call_type="embeddings")
 
         ## ROUTE TO CORRECT ENDPOINT ##
         if llm_router is not None and data["model"] in router_model_names: # model in router model list
             response = await llm_router.aembedding(**data)
         elif llm_router is not None and data["model"] in llm_router.deployment_names: # model in router deployments, calling a specific deployment on the router
             response = await llm_router.aembedding(**data, specific_deployment = True)
+        elif llm_router is not None and llm_router.model_group_alias is not None and data["model"] in llm_router.model_group_alias: # model set in model_group_alias
+            response = await llm_router.aembedding(**data) # ensure this goes the llm_router, router will do the correct alias mapping
         else:
             response = await litellm.aembedding(**data)
         background_tasks.add_task(log_input_output, request, response) # background task for logging to OTEL
@@ -1133,7 +1113,19 @@ async def embeddings(request: Request, user_api_key_dict: UserAPIKeyAuth = Depen
except Exception as e:
await proxy_logging_obj.post_call_failure_hook(user_api_key_dict=user_api_key_dict, original_exception=e)
traceback.print_exc()
-raise e
+if isinstance(e, HTTPException):
+raise e
+else:
+error_traceback = traceback.format_exc()
+error_msg = f"{str(e)}\n\n{error_traceback}"
+try:
+status = e.status_code # type: ignore
+except:
+status = 500
+raise HTTPException(
+status_code=status,
+detail=error_msg
+)

#### KEY MANAGEMENT ####

@@ -1162,6 +1154,30 @@ async def generate_key_fn(request: Request, data: GenerateKeyRequest, Authorizat
response = await generate_key_helper_fn(**data_json)
return GenerateKeyResponse(key=response["token"], expires=response["expires"], user_id=response["user_id"])

+@router.post("/key/update", tags=["key management"], dependencies=[Depends(user_api_key_auth)])
+async def update_key_fn(request: Request, data: UpdateKeyRequest):
+"""
+Update an existing key
+"""
+global prisma_client
+try:
+data_json: dict = data.json()
+key = data_json.pop("key")
+# get the row from db
+if prisma_client is None:
+raise Exception("Not connected to DB!")
+
+non_default_values = {k: v for k, v in data_json.items() if v is not None}
+print(f"non_default_values: {non_default_values}")
+response = await prisma_client.update_data(token=key, data={**non_default_values, "token": key})
+return {"key": key, **non_default_values}
+# update based on remaining passed in values
+except Exception as e:
+raise HTTPException(
+status_code=status.HTTP_400_BAD_REQUEST,
+detail={"error": str(e)},
+)

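For reference, a minimal client-side sketch of calling the new /key/update endpoint added above. The proxy URL, master key, and the updated field are illustrative; the handler applies any non-null UpdateKeyRequest fields besides "key":

import requests  # assumes the proxy is running locally and `requests` is installed

response = requests.post(
    "http://0.0.0.0:8000/key/update",  # hypothetical proxy address
    headers={"Authorization": "Bearer sk-my-master-key"},  # hypothetical master key
    json={"key": "sk-existing-key", "models": ["gpt-3.5-turbo"]},  # "models" is an illustrative field
)
print(response.json())  # handler returns {"key": ..., **non_default_values}
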
@router.post("/key/delete", tags=["key management"], dependencies=[Depends(user_api_key_auth)])
|
@router.post("/key/delete", tags=["key management"], dependencies=[Depends(user_api_key_auth)])
|
||||||
async def delete_key_fn(request: Request, data: DeleteKeyRequest):
|
async def delete_key_fn(request: Request, data: DeleteKeyRequest):
|
||||||
try:
|
try:
|
||||||
|
@@ -1207,10 +1223,12 @@ async def add_new_model(model_params: ModelParams):

print_verbose(f"Loaded config: {config}")
# Add the new model to the config
+model_info = model_params.model_info.json()
+model_info = {k: v for k, v in model_info.items() if v is not None}
config['model_list'].append({
'model_name': model_params.model_name,
'litellm_params': model_params.litellm_params,
-'model_info': model_params.model_info
+'model_info': model_info
})

# Save the updated config
@@ -1228,7 +1246,7 @@ async def add_new_model(model_params: ModelParams):
raise HTTPException(status_code=500, detail=f"Internal Server Error: {str(e)}")

#### [BETA] - This is a beta endpoint, format might change based on user feedback https://github.com/BerriAI/litellm/issues/933. If you need a stable endpoint use /model/info
-@router.get("/v1/model/info", description="Provides more info about each model in /models, including config.yaml descriptions (except api key and api base)", tags=["model management"], dependencies=[Depends(user_api_key_auth)])
+@router.get("/model/info", description="Provides more info about each model in /models, including config.yaml descriptions (except api key and api base)", tags=["model management"], dependencies=[Depends(user_api_key_auth)])
async def model_info_v1(request: Request):
global llm_model_list, general_settings, user_config_file_path
# Load existing config
@@ -1256,7 +1274,7 @@ async def model_info_v1(request: Request):


#### [BETA] - This is a beta endpoint, format might change based on user feedback. - https://github.com/BerriAI/litellm/issues/933
-@router.get("/model/info", description="Provides more info about each model in /models, including config.yaml descriptions (except api key and api base)", tags=["model management"], dependencies=[Depends(user_api_key_auth)])
+@router.get("/v1/model/info", description="Provides more info about each model in /models, including config.yaml descriptions (except api key and api base)", tags=["model management"], dependencies=[Depends(user_api_key_auth)])
async def model_info(request: Request):
global llm_model_list, general_settings, user_config_file_path
# Load existing config
@@ -1341,46 +1359,107 @@ async def delete_model(model_info: ModelInfoDelete):
raise HTTPException(status_code=500, detail=f"Internal Server Error: {str(e)}")

#### EXPERIMENTAL QUEUING ####
-@router.post("/queue/request", dependencies=[Depends(user_api_key_auth)])
-async def async_queue_request(request: Request):
-global celery_fn, llm_model_list
-if celery_fn is not None:
-body = await request.body()
-body_str = body.decode()
+async def _litellm_chat_completions_worker(data, user_api_key_dict):
+"""
+worker to make litellm completions calls
+"""
+while True:
try:
-data = ast.literal_eval(body_str)
-except:
-data = json.loads(body_str)
+### CALL HOOKS ### - modify incoming data before calling the model
+data = await proxy_logging_obj.pre_call_hook(user_api_key_dict=user_api_key_dict, data=data, call_type="completion")
+
+print(f"_litellm_chat_completions_worker started")
+### ROUTE THE REQUEST ###
+router_model_names = [m["model_name"] for m in llm_model_list] if llm_model_list is not None else []
+if llm_router is not None and data["model"] in router_model_names: # model in router model list
+response = await llm_router.acompletion(**data)
+elif llm_router is not None and data["model"] in llm_router.deployment_names: # model in router deployments, calling a specific deployment on the router
+response = await llm_router.acompletion(**data, specific_deployment = True)
+elif llm_router is not None and llm_router.model_group_alias is not None and data["model"] in llm_router.model_group_alias: # model set in model_group_alias
+response = await llm_router.acompletion(**data)
+else: # router is not set
+response = await litellm.acompletion(**data)
+
+print(f"final response: {response}")
+return response
+except HTTPException as e:
+print(f"EXCEPTION RAISED IN _litellm_chat_completions_worker - {e.status_code}; {e.detail}")
+if e.status_code == 429 and "Max parallel request limit reached" in e.detail:
+print(f"Max parallel request limit reached!")
+timeout = litellm._calculate_retry_after(remaining_retries=3, max_retries=3, min_timeout=1)
+await asyncio.sleep(timeout)
+else:
+raise e
+
+
+@router.post("/queue/chat/completions", tags=["experimental"], dependencies=[Depends(user_api_key_auth)])
+async def async_queue_request(request: Request, model: Optional[str] = None, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), background_tasks: BackgroundTasks = BackgroundTasks()):
+global general_settings, user_debug, proxy_logging_obj
+"""
+v2 attempt at a background worker to handle queuing.
+
+Just supports /chat/completion calls currently.
+
+Now using a FastAPI background task + /chat/completions compatible endpoint
+"""
+try:
+data = {}
+data = await request.json() # type: ignore
+
+# Include original request and headers in the data
+data["proxy_server_request"] = {
+"url": str(request.url),
+"method": request.method,
+"headers": dict(request.headers),
+"body": copy.copy(data) # use copy instead of deepcopy
+}
+
+print_verbose(f"receiving data: {data}")
data["model"] = (
general_settings.get("completion_model", None) # server default
or user_model # model name passed via cli args
+or model # for azure deployments
or data["model"] # default passed in http request
)
-data["llm_model_list"] = llm_model_list
-print(f"data: {data}")
-job = celery_fn.apply_async(kwargs=data)
-return {"id": job.id, "url": f"/queue/response/{job.id}", "eta": 5, "status": "queued"}
-else:
+# users can pass in 'user' param to /chat/completions. Don't override it
+if data.get("user", None) is None and user_api_key_dict.user_id is not None:
+# if users are using user_api_key_auth, set `user` in `data`
+data["user"] = user_api_key_dict.user_id
+
+if "metadata" in data:
+print(f'received metadata: {data["metadata"]}')
+data["metadata"]["user_api_key"] = user_api_key_dict.api_key
+data["metadata"]["headers"] = dict(request.headers)
+else:
+data["metadata"] = {"user_api_key": user_api_key_dict.api_key}
+data["metadata"]["headers"] = dict(request.headers)
+
+global user_temperature, user_request_timeout, user_max_tokens, user_api_base
+# override with user settings, these are params passed via cli
+if user_temperature:
+data["temperature"] = user_temperature
+if user_request_timeout:
+data["request_timeout"] = user_request_timeout
+if user_max_tokens:
+data["max_tokens"] = user_max_tokens
+if user_api_base:
+data["api_base"] = user_api_base
+
+response = await asyncio.wait_for(_litellm_chat_completions_worker(data=data, user_api_key_dict=user_api_key_dict), timeout=litellm.request_timeout)
+
+if 'stream' in data and data['stream'] == True: # use generate_responses to stream responses
+return StreamingResponse(async_data_generator(user_api_key_dict=user_api_key_dict, response=response), media_type='text/event-stream')
+
+background_tasks.add_task(log_input_output, request, response) # background task for logging to OTEL
+return response
+except Exception as e:
+await proxy_logging_obj.post_call_failure_hook(user_api_key_dict=user_api_key_dict, original_exception=e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
-detail={"error": "Queue not initialized"},
+detail={"error": str(e)},
)

-@router.get("/queue/response/{task_id}", dependencies=[Depends(user_api_key_auth)])
-async def async_queue_response(request: Request, task_id: str):
-global celery_app_conn, async_result
-try:
-if celery_app_conn is not None and async_result is not None:
-job = async_result(task_id, app=celery_app_conn)
-if job.ready():
-return {"status": "finished", "result": job.result}
-else:
-return {'status': 'queued'}
-else:
-raise Exception()
-except Exception as e:
-return {"status": "finished", "result": str(e)}

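For reference, a minimal client-side sketch of the new /queue/chat/completions endpoint above. The URL and key are illustrative; the request body follows the /chat/completions shape the endpoint accepts:

import requests  # illustrative call against a locally running proxy

response = requests.post(
    "http://0.0.0.0:8000/queue/chat/completions",  # hypothetical proxy address
    headers={"Authorization": "Bearer sk-1234"},  # hypothetical proxy key
    json={
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": "this is a test request, write a short poem"}],
    },
)
print(response.json())  # the endpoint awaits the background worker and returns the completion
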
@router.get("/ollama_logs", dependencies=[Depends(user_api_key_auth)])
async def retrieve_server_log(request: Request):
@@ -1411,8 +1490,18 @@ async def config_yaml_endpoint(config_info: ConfigYAML):
return {"hello": "world"}


-@router.get("/test")
+@router.get("/test", tags=["health"])
async def test_endpoint(request: Request):
+"""
+A test endpoint that pings the proxy server to check if it's healthy.
+
+Parameters:
+request (Request): The incoming request.
+
+Returns:
+dict: A dictionary containing the route of the request URL.
+"""
+# ping the proxy server to check if its healthy
return {"route": request.url.path}

@router.get("/health", tags=["health"], dependencies=[Depends(user_api_key_auth)])
@@ -1470,4 +1559,27 @@ async def get_routes():
return {"routes": routes}


+@router.on_event("shutdown")
+async def shutdown_event():
+global prisma_client, master_key, user_custom_auth
+if prisma_client:
+print("Disconnecting from Prisma")
+await prisma_client.disconnect()
+
+## RESET CUSTOM VARIABLES ##
+cleanup_router_config_variables()
+
+def cleanup_router_config_variables():
+global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, use_background_health_checks, health_check_interval
+
+# Set all variables to None
+master_key = None
+user_config_file_path = None
+otel_logging = None
+user_custom_auth = None
+user_custom_auth_path = None
+use_background_health_checks = None
+health_check_interval = None
+

app.include_router(router)
@@ -1,23 +0,0 @@
-import openai
-client = openai.OpenAI(
-api_key="anything",
-# base_url="http://0.0.0.0:8000",
-)
-
-try:
-# request sent to model set on litellm proxy, `litellm --model`
-response = client.chat.completions.create(model="gpt-3.5-turbo", messages = [
-{
-"role": "user",
-"content": "this is a test request, write a short poem"
-},
-])
-
-print(response)
-# except openai.APITimeoutError:
-# print("Got openai Timeout Exception. Good job. The proxy mapped to OpenAI exceptions")
-except Exception as e:
-print("\n the proxy did not map to OpenAI exception. Instead got", e)
-print(e.type) # type: ignore
-print(e.message) # type: ignore
-print(e.code) # type: ignore
@@ -1,13 +1,13 @@
from typing import Optional, List, Any, Literal
-import os, subprocess, hashlib, importlib, asyncio
+import os, subprocess, hashlib, importlib, asyncio, copy
import litellm, backoff
from litellm.proxy._types import UserAPIKeyAuth
from litellm.caching import DualCache
from litellm.proxy.hooks.parallel_request_limiter import MaxParallelRequestsHandler
+from litellm.integrations.custom_logger import CustomLogger
def print_verbose(print_statement):
if litellm.set_verbose:
-print(print_statement) # noqa
+print(f"LiteLLM Proxy: {print_statement}") # noqa
### LOGGING ###
class ProxyLogging:
"""
@@ -26,7 +26,7 @@ class ProxyLogging:
pass

def _init_litellm_callbacks(self):
+print_verbose(f"INITIALIZING LITELLM CALLBACKS!")
litellm.callbacks.append(self.max_parallel_request_limiter)
for callback in litellm.callbacks:
if callback not in litellm.input_callback:
@@ -65,17 +65,13 @@ class ProxyLogging:
2. /embeddings
"""
try:
-self.call_details["data"] = data
-self.call_details["call_type"] = call_type
-## check if max parallel requests set
-if user_api_key_dict.max_parallel_requests is not None:
-## if set, check if request allowed
-await self.max_parallel_request_limiter.max_parallel_request_allow_request(
-max_parallel_requests=user_api_key_dict.max_parallel_requests,
-api_key=user_api_key_dict.api_key,
-user_api_key_cache=self.call_details["user_api_key_cache"])
+for callback in litellm.callbacks:
+if isinstance(callback, CustomLogger) and 'async_pre_call_hook' in vars(callback.__class__):
+response = await callback.async_pre_call_hook(user_api_key_dict=user_api_key_dict, cache=self.call_details["user_api_key_cache"], data=data, call_type=call_type)
+if response is not None:
+data = response

+print_verbose(f'final data being sent to {call_type} call: {data}')
return data
except Exception as e:
raise e
@@ -103,17 +99,13 @@ class ProxyLogging:
1. /chat/completions
2. /embeddings
"""
-# check if max parallel requests set
-if user_api_key_dict.max_parallel_requests is not None:
-## decrement call count if call failed
-if (hasattr(original_exception, "status_code")
-and original_exception.status_code == 429
-and "Max parallel request limit reached" in str(original_exception)):
-pass # ignore failed calls due to max limit being reached
-else:
-await self.max_parallel_request_limiter.async_log_failure_call(
-api_key=user_api_key_dict.api_key,
-user_api_key_cache=self.call_details["user_api_key_cache"])
+for callback in litellm.callbacks:
+try:
+if isinstance(callback, CustomLogger):
+await callback.async_post_call_failure_hook(user_api_key_dict=user_api_key_dict, original_exception=original_exception)
+except Exception as e:
+raise e
return

@@ -165,19 +157,20 @@ class PrismaClient:
async def get_data(self, token: str, expires: Optional[Any]=None):
try:
# check if plain text or hash
+hashed_token = token
if token.startswith("sk-"):
-token = self.hash_token(token=token)
+hashed_token = self.hash_token(token=token)
if expires:
response = await self.db.litellm_verificationtoken.find_first(
where={
-"token": token,
+"token": hashed_token,
"expires": {"gte": expires} # Check if the token is not expired
}
)
else:
response = await self.db.litellm_verificationtoken.find_unique(
where={
-"token": token
+"token": hashed_token
}
)
return response
@@ -200,18 +193,18 @@ class PrismaClient:
try:
token = data["token"]
hashed_token = self.hash_token(token=token)
-data["token"] = hashed_token
+db_data = copy.deepcopy(data)
+db_data["token"] = hashed_token

new_verification_token = await self.db.litellm_verificationtoken.upsert( # type: ignore
where={
'token': hashed_token,
},
data={
-"create": {**data}, #type: ignore
+"create": {**db_data}, #type: ignore
"update": {} # don't do anything if it already exists
}
)

return new_verification_token
except Exception as e:
asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e))
@@ -235,15 +228,16 @@ class PrismaClient:
if token.startswith("sk-"):
token = self.hash_token(token=token)

-data["token"] = token
+db_data = copy.deepcopy(data)
+db_data["token"] = token
response = await self.db.litellm_verificationtoken.update(
where={
"token": token
},
-data={**data} # type: ignore
+data={**db_data} # type: ignore
)
print_verbose("\033[91m" + f"DB write succeeded {response}" + "\033[0m")
-return {"token": token, "data": data}
+return {"token": token, "data": db_data}
except Exception as e:
asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e))
print_verbose("\033[91m" + f"DB write failed: {e}" + "\033[0m")
@@ -7,6 +7,7 @@
#
# Thank you ! We ❤️ you! - Krrish & Ishaan

+import copy
from datetime import datetime
from typing import Dict, List, Optional, Union, Literal, Any
import random, threading, time, traceback, uuid
@@ -17,6 +18,7 @@ import inspect, concurrent
from openai import AsyncOpenAI
from collections import defaultdict
from litellm.router_strategy.least_busy import LeastBusyLoggingHandler
+import copy
class Router:
"""
Example usage:
@@ -68,6 +70,7 @@ class Router:
redis_password: Optional[str] = None,
cache_responses: Optional[bool] = False,
cache_kwargs: dict = {}, # additional kwargs to pass to RedisCache (see caching.py)
+caching_groups: Optional[List[tuple]] = None, # if you want to cache across model groups
## RELIABILITY ##
num_retries: int = 0,
timeout: Optional[float] = None,
@@ -76,11 +79,13 @@ class Router:
fallbacks: List = [],
allowed_fails: Optional[int] = None,
context_window_fallbacks: List = [],
+model_group_alias: Optional[dict] = {},
routing_strategy: Literal["simple-shuffle", "least-busy", "usage-based-routing", "latency-based-routing"] = "simple-shuffle") -> None:

self.set_verbose = set_verbose
self.deployment_names: List = [] # names of models under litellm_params. ex. azure/chatgpt-v-2
if model_list:
+model_list = copy.deepcopy(model_list)
self.set_model_list(model_list)
self.healthy_deployments: List = self.model_list
self.deployment_latency_map = {}
@@ -99,6 +104,7 @@ class Router:
self.fail_calls: defaultdict = defaultdict(int) # dict to store fail_calls made to each model
self.success_calls: defaultdict = defaultdict(int) # dict to store success_calls made to each model
self.previous_models: List = [] # list to store failed calls (passed in as metadata to next call)
+self.model_group_alias: dict = model_group_alias or {} # dict to store aliases for router, ex. {"gpt-4": "gpt-3.5-turbo"}, all requests with gpt-4 -> get routed to gpt-3.5-turbo group

# make Router.chat.completions.create compatible for openai.chat.completions.create
self.chat = litellm.Chat(params=default_litellm_params)
@@ -107,9 +113,10 @@ class Router:
self.default_litellm_params = default_litellm_params
self.default_litellm_params.setdefault("timeout", timeout)
self.default_litellm_params.setdefault("max_retries", 0)
+self.default_litellm_params.setdefault("metadata", {}).update({"caching_groups": caching_groups})

### CACHING ###
-cache_type = "local" # default to an in-memory cache
+cache_type: Literal["local", "redis"] = "local" # default to an in-memory cache
redis_cache = None
cache_config = {}
if redis_url is not None or (redis_host is not None and redis_port is not None and redis_password is not None):
@@ -133,7 +140,7 @@ class Router:
if cache_responses:
if litellm.cache is None:
# the cache can be initialized on the proxy server. We should not overwrite it
-litellm.cache = litellm.Cache(type=cache_type, **cache_config)
+litellm.cache = litellm.Cache(type=cache_type, **cache_config) # type: ignore
self.cache_responses = cache_responses
self.cache = DualCache(redis_cache=redis_cache, in_memory_cache=InMemoryCache()) # use a dual cache (Redis+In-Memory) for tracking cooldowns, usage, etc.
### ROUTING SETUP ###
@@ -198,19 +205,10 @@ class Router:
data = deployment["litellm_params"].copy()
kwargs["model_info"] = deployment.get("model_info", {})
for k, v in self.default_litellm_params.items():
-if k not in data: # prioritize model-specific params > default router params
-data[k] = v
-########## remove -ModelID-XXXX from model ##############
-original_model_string = data["model"]
-# Find the index of "ModelID" in the string
-self.print_verbose(f"completion model: {original_model_string}")
-index_of_model_id = original_model_string.find("-ModelID")
-# Remove everything after "-ModelID" if it exists
-if index_of_model_id != -1:
-data["model"] = original_model_string[:index_of_model_id]
-else:
-data["model"] = original_model_string
+if k not in kwargs: # prioritize model-specific params > default router params
+kwargs[k] = v
+elif k == "metadata":
+kwargs[k].update(v)
model_client = self._get_client(deployment=deployment, kwargs=kwargs)
return litellm.completion(**{**data, "messages": messages, "caching": self.cache_responses, "client": model_client, **kwargs})
except Exception as e:
@@ -241,31 +239,25 @@ class Router:
**kwargs):
try:
self.print_verbose(f"Inside _acompletion()- model: {model}; kwargs: {kwargs}")
-original_model_string = None # set a default for this variable
deployment = self.get_available_deployment(model=model, messages=messages, specific_deployment=kwargs.pop("specific_deployment", None))
kwargs.setdefault("metadata", {}).update({"deployment": deployment["litellm_params"]["model"]})
kwargs["model_info"] = deployment.get("model_info", {})
data = deployment["litellm_params"].copy()
+model_name = data["model"]
for k, v in self.default_litellm_params.items():
-if k not in data: # prioritize model-specific params > default router params
-data[k] = v
-########## remove -ModelID-XXXX from model ##############
-original_model_string = data["model"]
-# Find the index of "ModelID" in the string
-index_of_model_id = original_model_string.find("-ModelID")
-# Remove everything after "-ModelID" if it exists
-if index_of_model_id != -1:
-data["model"] = original_model_string[:index_of_model_id]
-else:
-data["model"] = original_model_string
+if k not in kwargs: # prioritize model-specific params > default router params
+kwargs[k] = v
+elif k == "metadata":
+kwargs[k].update(v)
model_client = self._get_client(deployment=deployment, kwargs=kwargs, client_type="async")
-self.total_calls[original_model_string] +=1
+self.total_calls[model_name] +=1
response = await litellm.acompletion(**{**data, "messages": messages, "caching": self.cache_responses, "client": model_client, **kwargs})
-self.success_calls[original_model_string] +=1
+self.success_calls[model_name] +=1
return response
except Exception as e:
-if original_model_string is not None:
-self.fail_calls[original_model_string] +=1
+if model_name is not None:
+self.fail_calls[model_name] +=1
raise e

def text_completion(self,
@@ -283,8 +275,43 @@ class Router:

data = deployment["litellm_params"].copy()
for k, v in self.default_litellm_params.items():
-if k not in data: # prioritize model-specific params > default router params
-data[k] = v
+if k not in kwargs: # prioritize model-specific params > default router params
+kwargs[k] = v
+elif k == "metadata":
+kwargs[k].update(v)
+
+# call via litellm.completion()
+return litellm.text_completion(**{**data, "prompt": prompt, "caching": self.cache_responses, **kwargs}) # type: ignore
+except Exception as e:
+if self.num_retries > 0:
+kwargs["model"] = model
+kwargs["messages"] = messages
+kwargs["original_exception"] = e
+kwargs["original_function"] = self.completion
+return self.function_with_retries(**kwargs)
+else:
+raise e
+
+async def atext_completion(self,
+model: str,
+prompt: str,
+is_retry: Optional[bool] = False,
+is_fallback: Optional[bool] = False,
+is_async: Optional[bool] = False,
+**kwargs):
+try:
+kwargs.setdefault("metadata", {}).update({"model_group": model})
+messages=[{"role": "user", "content": prompt}]
+# pick the one that is available (lowest TPM/RPM)
+deployment = self.get_available_deployment(model=model, messages=messages, specific_deployment=kwargs.pop("specific_deployment", None))
+
+data = deployment["litellm_params"].copy()
+for k, v in self.default_litellm_params.items():
+if k not in kwargs: # prioritize model-specific params > default router params
+kwargs[k] = v
+elif k == "metadata":
+kwargs[k].update(v)
+
########## remove -ModelID-XXXX from model ##############
original_model_string = data["model"]
# Find the index of "ModelID" in the string
@@ -294,8 +321,9 @@ class Router:
data["model"] = original_model_string[:index_of_model_id]
else:
data["model"] = original_model_string
-# call via litellm.completion()
-return litellm.text_completion(**{**data, "prompt": prompt, "caching": self.cache_responses, **kwargs}) # type: ignore
+# call via litellm.atext_completion()
+response = await litellm.atext_completion(**{**data, "prompt": prompt, "caching": self.cache_responses, **kwargs}) # type: ignore
+return response
except Exception as e:
if self.num_retries > 0:
kwargs["model"] = model
@@ -313,21 +341,14 @@ class Router:
**kwargs) -> Union[List[float], None]:
# pick the one that is available (lowest TPM/RPM)
deployment = self.get_available_deployment(model=model, input=input, specific_deployment=kwargs.pop("specific_deployment", None))
-kwargs.setdefault("metadata", {}).update({"deployment": deployment["litellm_params"]["model"]})
-kwargs["model_info"] = deployment.get("model_info", {})
+kwargs.setdefault("model_info", {})
+kwargs.setdefault("metadata", {}).update({"model_group": model, "deployment": deployment["litellm_params"]["model"]}) # [TODO]: move to using async_function_with_fallbacks
data = deployment["litellm_params"].copy()
for k, v in self.default_litellm_params.items():
-if k not in data: # prioritize model-specific params > default router params
-data[k] = v
-########## remove -ModelID-XXXX from model ##############
-original_model_string = data["model"]
-# Find the index of "ModelID" in the string
-index_of_model_id = original_model_string.find("-ModelID")
-# Remove everything after "-ModelID" if it exists
-if index_of_model_id != -1:
-data["model"] = original_model_string[:index_of_model_id]
-else:
-data["model"] = original_model_string
+if k not in kwargs: # prioritize model-specific params > default router params
+kwargs[k] = v
+elif k == "metadata":
+kwargs[k].update(v)
model_client = self._get_client(deployment=deployment, kwargs=kwargs)
# call via litellm.embedding()
return litellm.embedding(**{**data, "input": input, "caching": self.cache_responses, "client": model_client, **kwargs})
@@ -339,21 +360,15 @@ class Router:
**kwargs) -> Union[List[float], None]:
# pick the one that is available (lowest TPM/RPM)
deployment = self.get_available_deployment(model=model, input=input, specific_deployment=kwargs.pop("specific_deployment", None))
-kwargs.setdefault("metadata", {}).update({"deployment": deployment["litellm_params"]["model"]})
+kwargs.setdefault("metadata", {}).update({"model_group": model, "deployment": deployment["litellm_params"]["model"]})
data = deployment["litellm_params"].copy()
kwargs["model_info"] = deployment.get("model_info", {})
for k, v in self.default_litellm_params.items():
-if k not in data: # prioritize model-specific params > default router params
-data[k] = v
-########## remove -ModelID-XXXX from model ##############
-original_model_string = data["model"]
-# Find the index of "ModelID" in the string
-index_of_model_id = original_model_string.find("-ModelID")
-# Remove everything after "-ModelID" if it exists
-if index_of_model_id != -1:
-data["model"] = original_model_string[:index_of_model_id]
-else:
-data["model"] = original_model_string
+if k not in kwargs: # prioritize model-specific params > default router params
+kwargs[k] = v
+elif k == "metadata":
+kwargs[k].update(v)
model_client = self._get_client(deployment=deployment, kwargs=kwargs, client_type="async")

return await litellm.aembedding(**{**data, "input": input, "caching": self.cache_responses, "client": model_client, **kwargs})
@@ -371,7 +386,7 @@ class Router:
self.print_verbose(f'Async Response: {response}')
return response
except Exception as e:
-self.print_verbose(f"An exception occurs: {e}")
+self.print_verbose(f"An exception occurs: {e}\n\n Traceback{traceback.format_exc()}")
original_exception = e
try:
self.print_verbose(f"Trying to fallback b/w models")
@@ -637,9 +652,10 @@ class Router:
model_name = kwargs.get('model', None) # i.e. gpt35turbo
custom_llm_provider = kwargs.get("litellm_params", {}).get('custom_llm_provider', None) # i.e. azure
metadata = kwargs.get("litellm_params", {}).get('metadata', None)
+deployment_id = kwargs.get("litellm_params", {}).get("model_info").get("id")
+self._set_cooldown_deployments(deployment_id) # setting deployment_id in cooldown deployments
if metadata:
deployment = metadata.get("deployment", None)
-self._set_cooldown_deployments(deployment)
deployment_exceptions = self.model_exception_map.get(deployment, [])
deployment_exceptions.append(exception_str)
self.model_exception_map[deployment] = deployment_exceptions
@@ -877,7 +893,7 @@ class Router:
return chosen_item

def set_model_list(self, model_list: list):
-self.model_list = model_list
+self.model_list = copy.deepcopy(model_list)
# we add api_base/api_key each model so load balancing between azure/gpt on api_base1 and api_base2 works
import os
for model in self.model_list:
@@ -889,23 +905,26 @@ class Router:
model["model_info"] = model_info
#### for OpenAI / Azure we need to initalize the Client for High Traffic ########
custom_llm_provider = litellm_params.get("custom_llm_provider")
-if custom_llm_provider is None:
-custom_llm_provider = model_name.split("/",1)[0]
+custom_llm_provider = custom_llm_provider or model_name.split("/",1)[0] or ""
+default_api_base = None
+default_api_key = None
+if custom_llm_provider in litellm.openai_compatible_providers:
+_, custom_llm_provider, api_key, api_base = litellm.get_llm_provider(model=model_name)
+default_api_base = api_base
+default_api_key = api_key
if (
model_name in litellm.open_ai_chat_completion_models
-or custom_llm_provider == "custom_openai"
-or custom_llm_provider == "deepinfra"
-or custom_llm_provider == "perplexity"
-or custom_llm_provider == "anyscale"
-or custom_llm_provider == "openai"
+or custom_llm_provider in litellm.openai_compatible_providers
or custom_llm_provider == "azure"
+or custom_llm_provider == "custom_openai"
+or custom_llm_provider == "openai"
or "ft:gpt-3.5-turbo" in model_name
or model_name in litellm.open_ai_embedding_models
):
# glorified / complicated reading of configs
# user can pass vars directly or they can pas os.environ/AZURE_API_KEY, in which case we will read the env
# we do this here because we init clients for Azure, OpenAI and we need to set the right key
-api_key = litellm_params.get("api_key")
+api_key = litellm_params.get("api_key") or default_api_key
if api_key and api_key.startswith("os.environ/"):
api_key_env_name = api_key.replace("os.environ/", "")
api_key = litellm.get_secret(api_key_env_name)
@@ -913,7 +932,7 @@ class Router:

api_base = litellm_params.get("api_base")
base_url = litellm_params.get("base_url")
-api_base = api_base or base_url # allow users to pass in `api_base` or `base_url` for azure
+api_base = api_base or base_url or default_api_base # allow users to pass in `api_base` or `base_url` for azure
if api_base and api_base.startswith("os.environ/"):
api_base_env_name = api_base.replace("os.environ/", "")
api_base = litellm.get_secret(api_base_env_name)
@@ -1049,12 +1068,6 @@ class Router:

############ End of initializing Clients for OpenAI/Azure ###################
self.deployment_names.append(model["litellm_params"]["model"])
-model_id = ""
-for key in model["litellm_params"]:
-if key != "api_key" and key != "metadata":
-model_id+= str(model["litellm_params"][key])
-model["litellm_params"]["model"] += "-ModelID-" + model_id

self.print_verbose(f"\n Initialized Model List {self.model_list}")

############ Users can either pass tpm/rpm as a litellm_param or a router param ###########
@@ -1115,38 +1128,41 @@ class Router:
if specific_deployment == True:
# users can also specify a specific deployment name. At this point we should check if they are just trying to call a specific deployment
for deployment in self.model_list:
-cleaned_model = litellm.utils.remove_model_id(deployment.get("litellm_params").get("model"))
-if cleaned_model == model:
+deployment_model = deployment.get("litellm_params").get("model")
+if deployment_model == model:
# User Passed a specific deployment name on their config.yaml, example azure/chat-gpt-v-2
# return the first deployment where the `model` matches the specificed deployment name
return deployment
raise ValueError(f"LiteLLM Router: Trying to call specific deployment, but Model:{model} does not exist in Model List: {self.model_list}")

# check if aliases set on litellm model alias map
-if model in litellm.model_group_alias_map:
-self.print_verbose(f"Using a model alias. Got Request for {model}, sending requests to {litellm.model_group_alias_map.get(model)}")
-model = litellm.model_group_alias_map[model]
+if model in self.model_group_alias:
+self.print_verbose(f"Using a model alias. Got Request for {model}, sending requests to {self.model_group_alias.get(model)}")
+model = self.model_group_alias[model]

## get healthy deployments
### get all deployments
-### filter out the deployments currently cooling down
healthy_deployments = [m for m in self.model_list if m["model_name"] == model]
if len(healthy_deployments) == 0:
# check if the user sent in a deployment name instead
healthy_deployments = [m for m in self.model_list if m["litellm_params"]["model"] == model]

self.print_verbose(f"initial list of deployments: {healthy_deployments}")

+# filter out the deployments currently cooling down
deployments_to_remove = []
+# cooldown_deployments is a list of model_id's cooling down, cooldown_deployments = ["16700539-b3cd-42f4-b426-6a12a1bb706a", "16700539-b3cd-42f4-b426-7899"]
cooldown_deployments = self._get_cooldown_deployments()
self.print_verbose(f"cooldown deployments: {cooldown_deployments}")
-### FIND UNHEALTHY DEPLOYMENTS
+# Find deployments in model_list whose model_id is cooling down
for deployment in healthy_deployments:
-deployment_name = deployment["litellm_params"]["model"]
-if deployment_name in cooldown_deployments:
+deployment_id = deployment["model_info"]["id"]
+if deployment_id in cooldown_deployments:
deployments_to_remove.append(deployment)
-### FILTER OUT UNHEALTHY DEPLOYMENTS
+# remove unhealthy deployments from healthy deployments
for deployment in deployments_to_remove:
healthy_deployments.remove(deployment)

self.print_verbose(f"healthy deployments: length {len(healthy_deployments)} {healthy_deployments}")
if len(healthy_deployments) == 0:
raise ValueError("No models available")
@@ -1222,11 +1238,14 @@ class Router:
raise ValueError("No models available.")

def flush_cache(self):
+litellm.cache = None
self.cache.flush_cache()

def reset(self):
## clean up on close
litellm.success_callback = []
+litellm.__async_success_callback = []
litellm.failure_callback = []
+litellm._async_failure_callback = []
self.flush_cache()
|
|
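For reference, the cooldown filter in the hunk above is a set-difference over deployment ids: collect the deployments whose `model_info["id"]` appears in the cooldown list, then remove them. A minimal standalone sketch of that logic (the helper name and example ids below are illustrative, not litellm's API):

```python
from typing import Dict, List

def filter_cooled_down(healthy_deployments: List[Dict], cooldown_deployments: List[str]) -> List[Dict]:
    """Drop deployments whose model_info.id is currently cooling down."""
    # build a removal list first so the list is never mutated while iterating it
    deployments_to_remove = []
    for deployment in healthy_deployments:
        deployment_id = deployment["model_info"]["id"]
        if deployment_id in cooldown_deployments:
            deployments_to_remove.append(deployment)
    for deployment in deployments_to_remove:
        healthy_deployments.remove(deployment)
    return healthy_deployments

# hypothetical ids, mirroring the comment in the diff
deployments = [
    {"model_name": "gpt-3.5-turbo", "model_info": {"id": "16700539-b3cd-42f4-b426-6a12a1bb706a"}},
    {"model_name": "gpt-3.5-turbo", "model_info": {"id": "16700539-b3cd-42f4-b426-7899"}},
]
print(filter_cooled_down(deployments, ["16700539-b3cd-42f4-b426-7899"]))  # only the first deployment remains
```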
litellm/tests/conftest.py (new file, 34 lines)
@@ -0,0 +1,34 @@
+# conftest.py
+
+import pytest, sys, os
+import importlib
+sys.path.insert(
+    0, os.path.abspath("../..")
+)  # Adds the parent directory to the system path
+import litellm
+
+
+@pytest.fixture(scope="function", autouse=True)
+def setup_and_teardown():
+    """
+    This fixture reloads litellm before every function. To speed up testing by removing callbacks being chained.
+    """
+    curr_dir = os.getcwd()  # Get the current working directory
+    sys.path.insert(0, os.path.abspath("../.."))  # Adds the project directory to the system path
+    import litellm
+    importlib.reload(litellm)
+    print(litellm)
+    # from litellm import Router, completion, aembedding, acompletion, embedding
+    yield
+
+def pytest_collection_modifyitems(config, items):
+    # Separate tests in 'test_amazing_proxy_custom_logger.py' and other tests
+    custom_logger_tests = [item for item in items if 'custom_logger' in item.parent.name]
+    other_tests = [item for item in items if 'custom_logger' not in item.parent.name]
+
+    # Sort tests based on their names
+    custom_logger_tests.sort(key=lambda x: x.name)
+    other_tests.sort(key=lambda x: x.name)
+
+    # Reorder the items list
+    items[:] = custom_logger_tests + other_tests
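The autouse fixture above works because `importlib.reload` re-executes litellm's module body, which re-initializes module-level state such as the callback lists. A quick illustration of that effect:

```python
import importlib
import litellm

litellm.success_callback.append("langfuse")  # state left behind by an earlier test
importlib.reload(litellm)                    # re-runs litellm/__init__.py
assert litellm.success_callback == []        # the callback list is fresh again
```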
litellm/tests/example_config_yaml/cache_no_params.yaml (new file, 7 lines)
@@ -0,0 +1,7 @@
+model_list:
+  - model_name: "openai-model"
+    litellm_params:
+      model: "gpt-3.5-turbo"
+
+litellm_settings:
+  cache: True

litellm/tests/example_config_yaml/cache_with_params.yaml (new file, 10 lines)
@@ -0,0 +1,10 @@
+model_list:
+  - model_name: "openai-model"
+    litellm_params:
+      model: "gpt-3.5-turbo"
+
+litellm_settings:
+  cache: True
+  cache_params:
+    supported_call_types: ["embedding", "aembedding"]
+    host: "localhost"

litellm/tests/langfuse.log (new file, 4 lines)
@@ -0,0 +1,4 @@
+uploading batch of 2 items
+successfully uploaded batch of 2 items
+uploading batch of 2 items
+successfully uploaded batch of 2 items
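The two example configs above exercise the proxy's `litellm_settings.cache` block with and without `cache_params`. Roughly speaking, the proxy turns that block into a `litellm.Cache(...)` assignment; a hedged sketch of the equivalent in plain Python (how the proxy wires this internally is not shown in this diff, and the host value is just the placeholder from the yaml):

```python
import litellm
from litellm.caching import Cache

# cache: True with no cache_params -> default cache
litellm.cache = Cache()

# cache: True with cache_params -> parameters forwarded to Cache()
litellm.cache = Cache(
    supported_call_types=["embedding", "aembedding"],  # only cache embedding calls
    host="localhost",                                   # placeholder from cache_with_params.yaml
)
```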
@@ -118,6 +118,7 @@ def test_cooldown_same_model_name():
                "api_key": os.getenv("AZURE_API_KEY"),
                "api_version": os.getenv("AZURE_API_VERSION"),
                "api_base": "BAD_API_BASE",
+               "tpm": 90
            },
        },
        {
@@ -126,7 +127,8 @@ def test_cooldown_same_model_name():
                "model": "azure/chatgpt-v-2",
                "api_key": os.getenv("AZURE_API_KEY"),
                "api_version": os.getenv("AZURE_API_VERSION"),
-               "api_base": os.getenv("AZURE_API_BASE")
+               "api_base": os.getenv("AZURE_API_BASE"),
+               "tpm": 0.000001
            },
        },
    ]
@@ -151,13 +153,14 @@ def test_cooldown_same_model_name():
            ]
        )
        print(router.model_list)
-       litellm_model_names = []
+       model_ids = []
        for model in router.model_list:
-           litellm_model_names.append(model["litellm_params"]["model"])
-       print("\n litellm model names ", litellm_model_names)
+           model_ids.append(model["model_info"]["id"])
+       print("\n litellm model ids ", model_ids)

        # example litellm_model_names ['azure/chatgpt-v-2-ModelID-64321', 'azure/chatgpt-v-2-ModelID-63960']
-       assert litellm_model_names[0] != litellm_model_names[1]  # ensure both models have a uuid added, and they have different names
+       assert model_ids[0] != model_ids[1]  # ensure both models have a uuid added, and they have different names

        print("\ngot response\n", response)
    except Exception as e:
        pytest.fail(f"Got unexpected exception on router! - {e}")
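The updated assertion relies on the Router attaching a unique `model_info["id"]` to every deployment it is given, even when two deployments point at the same underlying model. A hedged sketch of checking that outside the test (the api_key is a placeholder; no real calls are made):

```python
from litellm import Router

router = Router(model_list=[
    {"model_name": "gpt-3.5-turbo", "litellm_params": {"model": "gpt-3.5-turbo", "api_key": "sk-placeholder"}},
    {"model_name": "gpt-3.5-turbo", "litellm_params": {"model": "gpt-3.5-turbo", "api_key": "sk-placeholder"}},
])
ids = [m["model_info"]["id"] for m in router.model_list]
assert len(set(ids)) == len(ids)  # each deployment gets its own id
```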
@@ -9,9 +9,9 @@ import os, io
sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
-import pytest
+import pytest, asyncio
import litellm
-from litellm import embedding, completion, completion_cost, Timeout
+from litellm import embedding, completion, completion_cost, Timeout, acompletion
from litellm import RateLimitError
import json
import os
@@ -63,6 +63,27 @@ def load_vertex_ai_credentials():
    # Export the temporary file as GOOGLE_APPLICATION_CREDENTIALS
    os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = os.path.abspath(temp_file.name)

+@pytest.mark.asyncio
+async def get_response():
+    load_vertex_ai_credentials()
+    prompt = '\ndef count_nums(arr):\n """\n Write a function count_nums which takes an array of integers and returns\n the number of elements which has a sum of digits > 0.\n If a number is negative, then its first signed digit will be negative:\n e.g. -123 has signed digits -1, 2, and 3.\n >>> count_nums([]) == 0\n >>> count_nums([-1, 11, -11]) == 1\n >>> count_nums([1, 1, 2]) == 3\n """\n'
+    try:
+        response = await acompletion(
+            model="gemini-pro",
+            messages=[
+                {
+                    "role": "system",
+                    "content": "Complete the given code with no more explanation. Remember that there is a 4-space indent before the first line of your generated code.",
+                },
+                {"role": "user", "content": prompt},
+            ],
+        )
+        return response
+    except litellm.UnprocessableEntityError as e:
+        pass
+    except Exception as e:
+        pytest.fail(f"An error occurred - {str(e)}")
+

def test_vertex_ai():
    import random
@@ -72,14 +93,15 @@ def test_vertex_ai():
    litellm.set_verbose=False
    litellm.vertex_project = "hardy-device-386718"

-   test_models = random.sample(test_models, 4)
+   test_models = random.sample(test_models, 1)
+   test_models += litellm.vertex_language_models  # always test gemini-pro
    for model in test_models:
        try:
-           if model in ["code-gecko@001", "code-gecko@latest", "code-bison@001", "text-bison@001"]:
+           if model in ["code-gecko", "code-gecko@001", "code-gecko@002", "code-gecko@latest", "code-bison@001", "text-bison@001"]:
                # our account does not have access to this model
                continue
            print("making request", model)
-           response = completion(model=model, messages=[{'role': 'user', 'content': 'hi'}])
+           response = completion(model=model, messages=[{'role': 'user', 'content': 'hi'}], temperature=0.7)
            print("\nModel Response", response)
            print(response)
            assert type(response.choices[0].message.content) == str
@@ -95,10 +117,11 @@ def test_vertex_ai_stream():
    import random

    test_models = litellm.vertex_chat_models + litellm.vertex_code_chat_models + litellm.vertex_text_models + litellm.vertex_code_text_models
-   test_models = random.sample(test_models, 4)
+   test_models = random.sample(test_models, 1)
+   test_models += litellm.vertex_language_models  # always test gemini-pro
    for model in test_models:
        try:
-           if model in ["code-gecko@001", "code-gecko@latest", "code-bison@001", "text-bison@001"]:
+           if model in ["code-gecko", "code-gecko@001", "code-gecko@002", "code-gecko@latest", "code-bison@001", "text-bison@001"]:
                # our account does not have access to this model
                continue
            print("making request", model)
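For context, the tests above reduce to calling litellm's `completion`/`acompletion` against Vertex AI once credentials and a project are configured. A minimal hedged sketch (the credentials path and project id are placeholders):

```python
import os
import litellm
from litellm import completion

os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = "/path/to/vertex_key.json"  # placeholder
litellm.vertex_project = "my-gcp-project"                                   # placeholder
litellm.vertex_location = "us-central1"

response = completion(
    model="gemini-pro",  # routed to Vertex AI, as in the tests above
    messages=[{"role": "user", "content": "hi"}],
    temperature=0.7,
)
print(response.choices[0].message.content)
```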
|
@ -115,3 +138,199 @@ def test_vertex_ai_stream():
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
pytest.fail(f"Error occurred: {e}")
|
pytest.fail(f"Error occurred: {e}")
|
||||||
# test_vertex_ai_stream()
|
# test_vertex_ai_stream()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_vertexai_response():
|
||||||
|
import random
|
||||||
|
load_vertex_ai_credentials()
|
||||||
|
test_models = litellm.vertex_chat_models + litellm.vertex_code_chat_models + litellm.vertex_text_models + litellm.vertex_code_text_models
|
||||||
|
test_models = random.sample(test_models, 1)
|
||||||
|
test_models += litellm.vertex_language_models # always test gemini-pro
|
||||||
|
for model in test_models:
|
||||||
|
print(f'model being tested in async call: {model}')
|
||||||
|
if model in ["code-gecko", "code-gecko@001", "code-gecko@002", "code-gecko@latest", "code-bison@001", "text-bison@001"]:
|
||||||
|
# our account does not have access to this model
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
user_message = "Hello, how are you?"
|
||||||
|
messages = [{"content": user_message, "role": "user"}]
|
||||||
|
response = await acompletion(model=model, messages=messages, temperature=0.7, timeout=5)
|
||||||
|
print(f"response: {response}")
|
||||||
|
except litellm.Timeout as e:
|
||||||
|
pass
|
||||||
|
except Exception as e:
|
||||||
|
pytest.fail(f"An exception occurred: {e}")
|
||||||
|
|
||||||
|
# asyncio.run(test_async_vertexai_response())
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_vertexai_streaming_response():
|
||||||
|
import random
|
||||||
|
load_vertex_ai_credentials()
|
||||||
|
test_models = litellm.vertex_chat_models + litellm.vertex_code_chat_models + litellm.vertex_text_models + litellm.vertex_code_text_models
|
||||||
|
test_models = random.sample(test_models, 1)
|
||||||
|
test_models += litellm.vertex_language_models # always test gemini-pro
|
||||||
|
for model in test_models:
|
||||||
|
if model in ["code-gecko", "code-gecko@001", "code-gecko@002", "code-gecko@latest", "code-bison@001", "text-bison@001"]:
|
||||||
|
# our account does not have access to this model
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
user_message = "Hello, how are you?"
|
||||||
|
messages = [{"content": user_message, "role": "user"}]
|
||||||
|
response = await acompletion(model="gemini-pro", messages=messages, temperature=0.7, timeout=5, stream=True)
|
||||||
|
print(f"response: {response}")
|
||||||
|
complete_response = ""
|
||||||
|
async for chunk in response:
|
||||||
|
print(f"chunk: {chunk}")
|
||||||
|
complete_response += chunk.choices[0].delta.content
|
||||||
|
print(f"complete_response: {complete_response}")
|
||||||
|
assert len(complete_response) > 0
|
||||||
|
except litellm.Timeout as e:
|
||||||
|
pass
|
||||||
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
|
pytest.fail(f"An exception occurred: {e}")
|
||||||
|
|
||||||
|
# asyncio.run(test_async_vertexai_streaming_response())
|
||||||
|
|
||||||
|
def test_gemini_pro_vision():
|
||||||
|
try:
|
||||||
|
load_vertex_ai_credentials()
|
||||||
|
litellm.set_verbose = True
|
||||||
|
litellm.num_retries=0
|
||||||
|
resp = litellm.completion(
|
||||||
|
model = "vertex_ai/gemini-pro-vision",
|
||||||
|
messages=[
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": "Whats in this image?"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "image_url",
|
||||||
|
"image_url": {
|
||||||
|
"url": "gs://cloud-samples-data/generative-ai/image/boats.jpeg"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
)
|
||||||
|
print(resp)
|
||||||
|
except Exception as e:
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
|
raise e
|
||||||
|
# test_gemini_pro_vision()
|
||||||
|
|
||||||
|
|
||||||
|
# Extra gemini Vision tests for completion + stream, async, async + stream
|
||||||
|
# if we run into issues with gemini, we will also add these to our ci/cd pipeline
|
||||||
|
# def test_gemini_pro_vision_stream():
|
||||||
|
# try:
|
||||||
|
# litellm.set_verbose = False
|
||||||
|
# litellm.num_retries=0
|
||||||
|
# print("streaming response from gemini-pro-vision")
|
||||||
|
# resp = litellm.completion(
|
||||||
|
# model = "vertex_ai/gemini-pro-vision",
|
||||||
|
# messages=[
|
||||||
|
# {
|
||||||
|
# "role": "user",
|
||||||
|
# "content": [
|
||||||
|
# {
|
||||||
|
# "type": "text",
|
||||||
|
# "text": "Whats in this image?"
|
||||||
|
# },
|
||||||
|
# {
|
||||||
|
# "type": "image_url",
|
||||||
|
# "image_url": {
|
||||||
|
# "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
|
||||||
|
# }
|
||||||
|
# }
|
||||||
|
# ]
|
||||||
|
# }
|
||||||
|
# ],
|
||||||
|
# stream=True
|
||||||
|
# )
|
||||||
|
# print(resp)
|
||||||
|
# for chunk in resp:
|
||||||
|
# print(chunk)
|
||||||
|
# except Exception as e:
|
||||||
|
# import traceback
|
||||||
|
# traceback.print_exc()
|
||||||
|
# raise e
|
||||||
|
# test_gemini_pro_vision_stream()
|
||||||
|
|
||||||
|
# def test_gemini_pro_vision_async():
|
||||||
|
# try:
|
||||||
|
# litellm.set_verbose = True
|
||||||
|
# litellm.num_retries=0
|
||||||
|
# async def test():
|
||||||
|
# resp = await litellm.acompletion(
|
||||||
|
# model = "vertex_ai/gemini-pro-vision",
|
||||||
|
# messages=[
|
||||||
|
# {
|
||||||
|
# "role": "user",
|
||||||
|
# "content": [
|
||||||
|
# {
|
||||||
|
# "type": "text",
|
||||||
|
# "text": "Whats in this image?"
|
||||||
|
# },
|
||||||
|
# {
|
||||||
|
# "type": "image_url",
|
||||||
|
# "image_url": {
|
||||||
|
# "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
|
||||||
|
# }
|
||||||
|
# }
|
||||||
|
# ]
|
||||||
|
# }
|
||||||
|
# ],
|
||||||
|
# )
|
||||||
|
# print("async response gemini pro vision")
|
||||||
|
# print(resp)
|
||||||
|
# asyncio.run(test())
|
||||||
|
# except Exception as e:
|
||||||
|
# import traceback
|
||||||
|
# traceback.print_exc()
|
||||||
|
# raise e
|
||||||
|
# test_gemini_pro_vision_async()
|
||||||
|
|
||||||
|
|
||||||
|
# def test_gemini_pro_vision_async_stream():
|
||||||
|
# try:
|
||||||
|
# litellm.set_verbose = True
|
||||||
|
# litellm.num_retries=0
|
||||||
|
# async def test():
|
||||||
|
# resp = await litellm.acompletion(
|
||||||
|
# model = "vertex_ai/gemini-pro-vision",
|
||||||
|
# messages=[
|
||||||
|
# {
|
||||||
|
# "role": "user",
|
||||||
|
# "content": [
|
||||||
|
# {
|
||||||
|
# "type": "text",
|
||||||
|
# "text": "Whats in this image?"
|
||||||
|
# },
|
||||||
|
# {
|
||||||
|
# "type": "image_url",
|
||||||
|
# "image_url": {
|
||||||
|
# "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
|
||||||
|
# }
|
||||||
|
# }
|
||||||
|
# ]
|
||||||
|
# }
|
||||||
|
# ],
|
||||||
|
# stream=True
|
||||||
|
# )
|
||||||
|
# print("async response gemini pro vision")
|
||||||
|
# print(resp)
|
||||||
|
# for chunk in resp:
|
||||||
|
# print(chunk)
|
||||||
|
# asyncio.run(test())
|
||||||
|
# except Exception as e:
|
||||||
|
# import traceback
|
||||||
|
# traceback.print_exc()
|
||||||
|
# raise e
|
||||||
|
# test_gemini_pro_vision_async()
|
|
@@ -29,16 +29,19 @@ def generate_random_word(length=4):
messages = [{"role": "user", "content": "who is ishaan 5222"}]
def test_caching_v2(): # test in memory cache
    try:
+       litellm.set_verbose=True
        litellm.cache = Cache()
        response1 = completion(model="gpt-3.5-turbo", messages=messages, caching=True)
        response2 = completion(model="gpt-3.5-turbo", messages=messages, caching=True)
        print(f"response1: {response1}")
        print(f"response2: {response2}")
        litellm.cache = None # disable cache
+       litellm.success_callback = []
+       litellm._async_success_callback = []
        if response2['choices'][0]['message']['content'] != response1['choices'][0]['message']['content']:
            print(f"response1: {response1}")
            print(f"response2: {response2}")
-           pytest.fail(f"Error occurred: {e}")
+           pytest.fail(f"Error occurred:")
    except Exception as e:
        print(f"error occurred: {traceback.format_exc()}")
        pytest.fail(f"Error occurred: {e}")
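The extra teardown lines added above matter because setting `litellm.cache` also appears to register a cache hook on the success-callback lists; clearing both prevents state from leaking into later tests. The pattern in isolation (a sketch of the teardown only, not new test logic):

```python
import litellm
from litellm.caching import Cache

litellm.cache = Cache()              # enables caching (and hooks into success callbacks)
# ... run completions with caching=True ...
litellm.cache = None                 # disable cache
litellm.success_callback = []        # drop callbacks added while caching was on
litellm._async_success_callback = []
```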
|
@ -58,6 +61,8 @@ def test_caching_with_models_v2():
|
||||||
print(f"response2: {response2}")
|
print(f"response2: {response2}")
|
||||||
print(f"response3: {response3}")
|
print(f"response3: {response3}")
|
||||||
litellm.cache = None
|
litellm.cache = None
|
||||||
|
litellm.success_callback = []
|
||||||
|
litellm._async_success_callback = []
|
||||||
if response3['choices'][0]['message']['content'] == response2['choices'][0]['message']['content']:
|
if response3['choices'][0]['message']['content'] == response2['choices'][0]['message']['content']:
|
||||||
# if models are different, it should not return cached response
|
# if models are different, it should not return cached response
|
||||||
print(f"response2: {response2}")
|
print(f"response2: {response2}")
|
||||||
|
@ -91,6 +96,8 @@ def test_embedding_caching():
|
||||||
print(f"Embedding 2 response time: {end_time - start_time} seconds")
|
print(f"Embedding 2 response time: {end_time - start_time} seconds")
|
||||||
|
|
||||||
litellm.cache = None
|
litellm.cache = None
|
||||||
|
litellm.success_callback = []
|
||||||
|
litellm._async_success_callback = []
|
||||||
assert end_time - start_time <= 0.1 # ensure 2nd response comes in in under 0.1 s
|
assert end_time - start_time <= 0.1 # ensure 2nd response comes in in under 0.1 s
|
||||||
if embedding2['data'][0]['embedding'] != embedding1['data'][0]['embedding']:
|
if embedding2['data'][0]['embedding'] != embedding1['data'][0]['embedding']:
|
||||||
print(f"embedding1: {embedding1}")
|
print(f"embedding1: {embedding1}")
|
||||||
|
@ -145,6 +152,8 @@ def test_embedding_caching_azure():
|
||||||
print(f"Embedding 2 response time: {end_time - start_time} seconds")
|
print(f"Embedding 2 response time: {end_time - start_time} seconds")
|
||||||
|
|
||||||
litellm.cache = None
|
litellm.cache = None
|
||||||
|
litellm.success_callback = []
|
||||||
|
litellm._async_success_callback = []
|
||||||
assert end_time - start_time <= 0.1 # ensure 2nd response comes in in under 0.1 s
|
assert end_time - start_time <= 0.1 # ensure 2nd response comes in in under 0.1 s
|
||||||
if embedding2['data'][0]['embedding'] != embedding1['data'][0]['embedding']:
|
if embedding2['data'][0]['embedding'] != embedding1['data'][0]['embedding']:
|
||||||
print(f"embedding1: {embedding1}")
|
print(f"embedding1: {embedding1}")
|
||||||
|
@ -175,6 +184,8 @@ def test_redis_cache_completion():
|
||||||
print("\nresponse 3", response3)
|
print("\nresponse 3", response3)
|
||||||
print("\nresponse 4", response4)
|
print("\nresponse 4", response4)
|
||||||
litellm.cache = None
|
litellm.cache = None
|
||||||
|
litellm.success_callback = []
|
||||||
|
litellm._async_success_callback = []
|
||||||
|
|
||||||
"""
|
"""
|
||||||
1 & 2 should be exactly the same
|
1 & 2 should be exactly the same
|
||||||
|
@ -226,6 +237,8 @@ def test_redis_cache_completion_stream():
|
||||||
assert response_1_content == response_2_content, f"Response 1 != Response 2. Same params, Response 1{response_1_content} != Response 2{response_2_content}"
|
assert response_1_content == response_2_content, f"Response 1 != Response 2. Same params, Response 1{response_1_content} != Response 2{response_2_content}"
|
||||||
litellm.success_callback = []
|
litellm.success_callback = []
|
||||||
litellm.cache = None
|
litellm.cache = None
|
||||||
|
litellm.success_callback = []
|
||||||
|
litellm._async_success_callback = []
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(e)
|
print(e)
|
||||||
litellm.success_callback = []
|
litellm.success_callback = []
|
||||||
|
@ -271,11 +284,53 @@ def test_redis_cache_acompletion_stream():
|
||||||
print("\nresponse 2", response_2_content)
|
print("\nresponse 2", response_2_content)
|
||||||
assert response_1_content == response_2_content, f"Response 1 != Response 2. Same params, Response 1{response_1_content} != Response 2{response_2_content}"
|
assert response_1_content == response_2_content, f"Response 1 != Response 2. Same params, Response 1{response_1_content} != Response 2{response_2_content}"
|
||||||
litellm.cache = None
|
litellm.cache = None
|
||||||
|
litellm.success_callback = []
|
||||||
|
litellm._async_success_callback = []
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(e)
|
print(e)
|
||||||
raise e
|
raise e
|
||||||
# test_redis_cache_acompletion_stream()
|
# test_redis_cache_acompletion_stream()
|
||||||
|
|
||||||
|
def test_redis_cache_acompletion_stream_bedrock():
|
||||||
|
import asyncio
|
||||||
|
try:
|
||||||
|
litellm.set_verbose = True
|
||||||
|
random_word = generate_random_word()
|
||||||
|
messages = [{"role": "user", "content": f"write a one sentence poem about: {random_word}"}]
|
||||||
|
litellm.cache = Cache(type="redis", host=os.environ['REDIS_HOST'], port=os.environ['REDIS_PORT'], password=os.environ['REDIS_PASSWORD'])
|
||||||
|
print("test for caching, streaming + completion")
|
||||||
|
response_1_content = ""
|
||||||
|
response_2_content = ""
|
||||||
|
|
||||||
|
async def call1():
|
||||||
|
nonlocal response_1_content
|
||||||
|
response1 = await litellm.acompletion(model="bedrock/anthropic.claude-v1", messages=messages, max_tokens=40, temperature=1, stream=True)
|
||||||
|
async for chunk in response1:
|
||||||
|
print(chunk)
|
||||||
|
response_1_content += chunk.choices[0].delta.content or ""
|
||||||
|
print(response_1_content)
|
||||||
|
asyncio.run(call1())
|
||||||
|
time.sleep(0.5)
|
||||||
|
print("\n\n Response 1 content: ", response_1_content, "\n\n")
|
||||||
|
|
||||||
|
async def call2():
|
||||||
|
nonlocal response_2_content
|
||||||
|
response2 = await litellm.acompletion(model="bedrock/anthropic.claude-v1", messages=messages, max_tokens=40, temperature=1, stream=True)
|
||||||
|
async for chunk in response2:
|
||||||
|
print(chunk)
|
||||||
|
response_2_content += chunk.choices[0].delta.content or ""
|
||||||
|
print(response_2_content)
|
||||||
|
asyncio.run(call2())
|
||||||
|
print("\nresponse 1", response_1_content)
|
||||||
|
print("\nresponse 2", response_2_content)
|
||||||
|
assert response_1_content == response_2_content, f"Response 1 != Response 2. Same params, Response 1{response_1_content} != Response 2{response_2_content}"
|
||||||
|
litellm.cache = None
|
||||||
|
litellm.success_callback = []
|
||||||
|
litellm._async_success_callback = []
|
||||||
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
|
raise e
|
||||||
|
# test_redis_cache_acompletion_stream_bedrock()
|
||||||
# redis cache with custom keys
|
# redis cache with custom keys
|
||||||
def custom_get_cache_key(*args, **kwargs):
|
def custom_get_cache_key(*args, **kwargs):
|
||||||
# return key to use for your cache:
|
# return key to use for your cache:
|
||||||
|
@ -312,9 +367,44 @@ def test_custom_redis_cache_with_key():
|
||||||
if response3['choices'][0]['message']['content'] == response2['choices'][0]['message']['content']:
|
if response3['choices'][0]['message']['content'] == response2['choices'][0]['message']['content']:
|
||||||
pytest.fail(f"Error occurred:")
|
pytest.fail(f"Error occurred:")
|
||||||
litellm.cache = None
|
litellm.cache = None
|
||||||
|
litellm.success_callback = []
|
||||||
|
litellm._async_success_callback = []
|
||||||
|
|
||||||
# test_custom_redis_cache_with_key()
|
# test_custom_redis_cache_with_key()
|
||||||
|
|
||||||
|
def test_cache_override():
|
||||||
|
# test if we can override the cache, when `caching=False` but litellm.cache = Cache() is set
|
||||||
|
# in this case it should not return cached responses
|
||||||
|
litellm.cache = Cache()
|
||||||
|
print("Testing cache override")
|
||||||
|
litellm.set_verbose=True
|
||||||
|
|
||||||
|
# test embedding
|
||||||
|
response1 = embedding(
|
||||||
|
model = "text-embedding-ada-002",
|
||||||
|
input=[
|
||||||
|
"hello who are you"
|
||||||
|
],
|
||||||
|
caching = False
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
response2 = embedding(
|
||||||
|
model = "text-embedding-ada-002",
|
||||||
|
input=[
|
||||||
|
"hello who are you"
|
||||||
|
],
|
||||||
|
caching = False
|
||||||
|
)
|
||||||
|
|
||||||
|
end_time = time.time()
|
||||||
|
print(f"Embedding 2 response time: {end_time - start_time} seconds")
|
||||||
|
|
||||||
|
assert end_time - start_time > 0.1 # ensure 2nd response comes in over 0.1s. This should not be cached.
|
||||||
|
# test_cache_override()
|
||||||
|
|
||||||
|
|
||||||
def test_custom_redis_cache_params():
|
def test_custom_redis_cache_params():
|
||||||
# test if we can init redis with **kwargs
|
# test if we can init redis with **kwargs
|
||||||
|
@ -333,6 +423,8 @@ def test_custom_redis_cache_params():
|
||||||
|
|
||||||
print(litellm.cache.cache.redis_client)
|
print(litellm.cache.cache.redis_client)
|
||||||
litellm.cache = None
|
litellm.cache = None
|
||||||
|
litellm.success_callback = []
|
||||||
|
litellm._async_success_callback = []
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
pytest.fail(f"Error occurred:", e)
|
pytest.fail(f"Error occurred:", e)
|
||||||
|
|
||||||
|
@@ -340,15 +432,58 @@
def test_get_cache_key():
    from litellm.caching import Cache
    try:
+       print("Testing get_cache_key")
        cache_instance = Cache()
        cache_key = cache_instance.get_cache_key(**{'model': 'gpt-3.5-turbo', 'messages': [{'role': 'user', 'content': 'write a one sentence poem about: 7510'}], 'max_tokens': 40, 'temperature': 0.2, 'stream': True, 'litellm_call_id': 'ffe75e7e-8a07-431f-9a74-71a5b9f35f0b', 'litellm_logging_obj': {}}
        )
+       cache_key_2 = cache_instance.get_cache_key(**{'model': 'gpt-3.5-turbo', 'messages': [{'role': 'user', 'content': 'write a one sentence poem about: 7510'}], 'max_tokens': 40, 'temperature': 0.2, 'stream': True, 'litellm_call_id': 'ffe75e7e-8a07-431f-9a74-71a5b9f35f0b', 'litellm_logging_obj': {}}
+       )
        assert cache_key == "model: gpt-3.5-turbomessages: [{'role': 'user', 'content': 'write a one sentence poem about: 7510'}]temperature: 0.2max_tokens: 40"
+       assert cache_key == cache_key_2, f"{cache_key} != {cache_key_2}. The same kwargs should have the same cache key across runs"
+
+       embedding_cache_key = cache_instance.get_cache_key(
+           **{'model': 'azure/azure-embedding-model', 'api_base': 'https://openai-gpt-4-test-v-1.openai.azure.com/',
+           'api_key': '', 'api_version': '2023-07-01-preview',
+           'timeout': None, 'max_retries': 0, 'input': ['hi who is ishaan'],
+           'caching': True,
+           'client': "<openai.lib.azure.AsyncAzureOpenAI object at 0x12b6a1060>"
+           }
+       )
+
+       print(embedding_cache_key)
+
+       assert embedding_cache_key == "model: azure/azure-embedding-modelinput: ['hi who is ishaan']", f"{embedding_cache_key} != 'model: azure/azure-embedding-modelinput: ['hi who is ishaan']'. The same kwargs should have the same cache key across runs"
+
+       # Proxy - embedding cache, test if embedding key, gets model_group and not model
+       embedding_cache_key_2 = cache_instance.get_cache_key(
+           **{'model': 'azure/azure-embedding-model', 'api_base': 'https://openai-gpt-4-test-v-1.openai.azure.com/',
+           'api_key': '', 'api_version': '2023-07-01-preview',
+           'timeout': None, 'max_retries': 0, 'input': ['hi who is ishaan'],
+           'caching': True,
+           'client': "<openai.lib.azure.AsyncAzureOpenAI object at 0x12b6a1060>",
+           'proxy_server_request': {'url': 'http://0.0.0.0:8000/embeddings',
+               'method': 'POST',
+               'headers':
+               {'host': '0.0.0.0:8000', 'user-agent': 'curl/7.88.1', 'accept': '*/*', 'content-type': 'application/json',
+               'content-length': '80'},
+               'body': {'model': 'azure-embedding-model', 'input': ['hi who is ishaan']}},
+           'user': None,
+           'metadata': {'user_api_key': None,
+               'headers': {'host': '0.0.0.0:8000', 'user-agent': 'curl/7.88.1', 'accept': '*/*', 'content-type': 'application/json', 'content-length': '80'},
+               'model_group': 'EMBEDDING_MODEL_GROUP',
+               'deployment': 'azure/azure-embedding-model-ModelID-azure/azure-embedding-modelhttps://openai-gpt-4-test-v-1.openai.azure.com/2023-07-01-preview'},
+           'model_info': {'mode': 'embedding', 'base_model': 'text-embedding-ada-002', 'id': '20b2b515-f151-4dd5-a74f-2231e2f54e29'},
+           'litellm_call_id': '2642e009-b3cd-443d-b5dd-bb7d56123b0e', 'litellm_logging_obj': '<litellm.utils.Logging object at 0x12f1bddb0>'}
+       )
+
+       print(embedding_cache_key_2)
+       assert embedding_cache_key_2 == "model: EMBEDDING_MODEL_GROUPinput: ['hi who is ishaan']"
+       print("passed!")
    except Exception as e:
        traceback.print_exc()
        pytest.fail(f"Error occurred:", e)

-# test_get_cache_key()
+test_get_cache_key()

# test_custom_redis_cache_params()
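As the assertions above show, the cache key is a deterministic concatenation of a small whitelist of parameters (model, messages/input, temperature, max_tokens), and for proxy embedding requests the `metadata.model_group` is used in place of the deployment-specific model. A toy reconstruction of that idea (not litellm's actual implementation):

```python
def sketch_cache_key(**kwargs) -> str:
    """Toy version of the key format asserted in test_get_cache_key."""
    metadata = kwargs.get("metadata") or {}
    model = metadata.get("model_group") or kwargs["model"]  # proxy caches on the model group
    parts = [f"model: {model}"]
    for param in ("messages", "input", "temperature", "max_tokens"):
        if param in kwargs:
            parts.append(f"{param}: {kwargs[param]}")
    return "".join(parts)

key = sketch_cache_key(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "write a one sentence poem about: 7510"}],
    temperature=0.2,
    max_tokens=40,
)
print(key)  # model: gpt-3.5-turbomessages: [...]temperature: 0.2max_tokens: 40
```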
@ -21,6 +21,13 @@ messages = [{"content": user_message, "role": "user"}]
|
||||||
def logger_fn(user_model_dict):
|
def logger_fn(user_model_dict):
|
||||||
print(f"user_model_dict: {user_model_dict}")
|
print(f"user_model_dict: {user_model_dict}")
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def reset_callbacks():
|
||||||
|
print("\npytest fixture - resetting callbacks")
|
||||||
|
litellm.success_callback = []
|
||||||
|
litellm._async_success_callback = []
|
||||||
|
litellm.failure_callback = []
|
||||||
|
litellm.callbacks = []
|
||||||
|
|
||||||
def test_completion_custom_provider_model_name():
|
def test_completion_custom_provider_model_name():
|
||||||
try:
|
try:
|
||||||
|
@ -61,6 +68,25 @@ def test_completion_claude():
|
||||||
|
|
||||||
# test_completion_claude()
|
# test_completion_claude()
|
||||||
|
|
||||||
|
def test_completion_mistral_api():
|
||||||
|
try:
|
||||||
|
litellm.set_verbose=True
|
||||||
|
response = completion(
|
||||||
|
model="mistral/mistral-tiny",
|
||||||
|
messages=[
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "Hey, how's it going?",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
safe_mode = True
|
||||||
|
)
|
||||||
|
# Add any assertions here to check the response
|
||||||
|
print(response)
|
||||||
|
except Exception as e:
|
||||||
|
pytest.fail(f"Error occurred: {e}")
|
||||||
|
# test_completion_mistral_api()
|
||||||
|
|
||||||
def test_completion_claude2_1():
|
def test_completion_claude2_1():
|
||||||
try:
|
try:
|
||||||
print("claude2.1 test request")
|
print("claude2.1 test request")
|
||||||
|
@ -287,7 +313,7 @@ def hf_test_completion_tgi():
|
||||||
print(response)
|
print(response)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
pytest.fail(f"Error occurred: {e}")
|
pytest.fail(f"Error occurred: {e}")
|
||||||
hf_test_completion_tgi()
|
# hf_test_completion_tgi()
|
||||||
|
|
||||||
# ################### Hugging Face Conversational models ########################
|
# ################### Hugging Face Conversational models ########################
|
||||||
# def hf_test_completion_conv():
|
# def hf_test_completion_conv():
|
||||||
|
@ -611,7 +637,7 @@ def test_completion_azure_key_completion_arg():
|
||||||
os.environ.pop("AZURE_API_KEY", None)
|
os.environ.pop("AZURE_API_KEY", None)
|
||||||
try:
|
try:
|
||||||
print("azure gpt-3.5 test\n\n")
|
print("azure gpt-3.5 test\n\n")
|
||||||
litellm.set_verbose=False
|
litellm.set_verbose=True
|
||||||
## Test azure call
|
## Test azure call
|
||||||
response = completion(
|
response = completion(
|
||||||
model="azure/chatgpt-v-2",
|
model="azure/chatgpt-v-2",
|
||||||
|
@ -696,6 +722,7 @@ def test_completion_azure():
|
||||||
print(response)
|
print(response)
|
||||||
|
|
||||||
cost = completion_cost(completion_response=response)
|
cost = completion_cost(completion_response=response)
|
||||||
|
assert cost > 0.0
|
||||||
print("Cost for azure completion request", cost)
|
print("Cost for azure completion request", cost)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
pytest.fail(f"Error occurred: {e}")
|
pytest.fail(f"Error occurred: {e}")
|
||||||
|
@ -1013,15 +1040,56 @@ def test_completion_together_ai():
|
||||||
# Add any assertions here to check the response
|
# Add any assertions here to check the response
|
||||||
print(response)
|
print(response)
|
||||||
cost = completion_cost(completion_response=response)
|
cost = completion_cost(completion_response=response)
|
||||||
|
assert cost > 0.0
|
||||||
print("Cost for completion call together-computer/llama-2-70b: ", f"${float(cost):.10f}")
|
print("Cost for completion call together-computer/llama-2-70b: ", f"${float(cost):.10f}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
pytest.fail(f"Error occurred: {e}")
|
pytest.fail(f"Error occurred: {e}")
|
||||||
|
|
||||||
|
def test_completion_together_ai_mixtral():
|
||||||
|
model_name = "together_ai/DiscoResearch/DiscoLM-mixtral-8x7b-v2"
|
||||||
|
try:
|
||||||
|
messages =[
|
||||||
|
{"role": "user", "content": "Who are you"},
|
||||||
|
{"role": "assistant", "content": "I am your helpful assistant."},
|
||||||
|
{"role": "user", "content": "Tell me a joke"},
|
||||||
|
]
|
||||||
|
response = completion(model=model_name, messages=messages, max_tokens=256, n=1, logger_fn=logger_fn)
|
||||||
|
# Add any assertions here to check the response
|
||||||
|
print(response)
|
||||||
|
cost = completion_cost(completion_response=response)
|
||||||
|
assert cost > 0.0
|
||||||
|
print("Cost for completion call together-computer/llama-2-70b: ", f"${float(cost):.10f}")
|
||||||
|
except litellm.Timeout as e:
|
||||||
|
pass
|
||||||
|
except Exception as e:
|
||||||
|
pytest.fail(f"Error occurred: {e}")
|
||||||
|
|
||||||
|
test_completion_together_ai_mixtral()
|
||||||
|
|
||||||
|
def test_completion_together_ai_yi_chat():
|
||||||
|
model_name = "together_ai/zero-one-ai/Yi-34B-Chat"
|
||||||
|
try:
|
||||||
|
messages =[
|
||||||
|
{"role": "user", "content": "What llm are you?"},
|
||||||
|
]
|
||||||
|
response = completion(model=model_name, messages=messages)
|
||||||
|
# Add any assertions here to check the response
|
||||||
|
print(response)
|
||||||
|
cost = completion_cost(completion_response=response)
|
||||||
|
assert cost > 0.0
|
||||||
|
print("Cost for completion call together-computer/llama-2-70b: ", f"${float(cost):.10f}")
|
||||||
|
except Exception as e:
|
||||||
|
pytest.fail(f"Error occurred: {e}")
|
||||||
|
# test_completion_together_ai_yi_chat()
|
||||||
|
|
||||||
# test_completion_together_ai()
|
# test_completion_together_ai()
|
||||||
def test_customprompt_together_ai():
|
def test_customprompt_together_ai():
|
||||||
try:
|
try:
|
||||||
litellm.set_verbose = False
|
litellm.set_verbose = False
|
||||||
litellm.num_retries = 0
|
litellm.num_retries = 0
|
||||||
|
print("in test_customprompt_together_ai")
|
||||||
|
print(litellm.success_callback)
|
||||||
|
print(litellm._async_success_callback)
|
||||||
response = completion(
|
response = completion(
|
||||||
model="together_ai/mistralai/Mistral-7B-Instruct-v0.1",
|
model="together_ai/mistralai/Mistral-7B-Instruct-v0.1",
|
||||||
messages=messages,
|
messages=messages,
|
||||||
|
@ -1030,7 +1098,6 @@ def test_customprompt_together_ai():
|
||||||
print(response)
|
print(response)
|
||||||
except litellm.exceptions.Timeout as e:
|
except litellm.exceptions.Timeout as e:
|
||||||
print(f"Timeout Error")
|
print(f"Timeout Error")
|
||||||
litellm.num_retries = 3 # reset retries
|
|
||||||
pass
|
pass
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"ERROR TYPE {type(e)}")
|
print(f"ERROR TYPE {type(e)}")
|
||||||
|
|
|
@@ -2,7 +2,7 @@ from litellm.integrations.custom_logger import CustomLogger
import inspect
import litellm

-class MyCustomHandler(CustomLogger):
+class testCustomCallbackProxy(CustomLogger):
    def __init__(self):
        self.success: bool = False # type: ignore
        self.failure: bool = False # type: ignore
@@ -55,8 +55,11 @@ class MyCustomHandler(CustomLogger):
        self.async_success = True
        print("Value of async success: ", self.async_success)
        print("\n kwargs: ", kwargs)
-       if kwargs.get("model") == "azure-embedding-model":
+       if kwargs.get("model") == "azure-embedding-model" or kwargs.get("model") == "ada":
+           print("Got an embedding model", kwargs.get("model"))
+           print("Setting embedding success to True")
            self.async_success_embedding = True
+           print("Value of async success embedding: ", self.async_success_embedding)
            self.async_embedding_kwargs = kwargs
            self.async_embedding_response = response_obj
        if kwargs.get("stream") == True:
@@ -79,6 +82,9 @@ class MyCustomHandler(CustomLogger):
        # tokens used in response
        usage = response_obj["usage"]

+       print("\n\n in custom callback vars my custom logger, ", vars(my_custom_logger))
+
+
        print(
            f"""
            Model: {model},
@@ -104,4 +110,4 @@ class MyCustomHandler(CustomLogger):

        self.async_completion_kwargs_fail = kwargs

-my_custom_logger = MyCustomHandler()
+my_custom_logger = testCustomCallbackProxy()

litellm/tests/test_configs/test_bad_config.yaml (new file, 16 lines)
@@ -0,0 +1,16 @@
+model_list:
+  - model_name: gpt-3.5-turbo
+    litellm_params:
+      api_key: bad-key
+      model: gpt-3.5-turbo
+  - model_name: azure-gpt-3.5-turbo
+    litellm_params:
+      model: azure/chatgpt-v-2
+      api_base: os.environ/AZURE_API_BASE
+      api_key: bad-key
+  - model_name: azure-embedding
+    litellm_params:
+      model: azure/azure-embedding-model
+      api_base: os.environ/AZURE_API_BASE
+      api_key: bad-key
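The renamed handler above implements litellm's `CustomLogger` interface (the same one exercised by `test_custom_callback_input.py`, added later in this commit). A minimal hedged sketch of registering such a handler directly, outside the proxy (assumes OPENAI_API_KEY is set):

```python
import litellm
from litellm.integrations.custom_logger import CustomLogger

class MyHandler(CustomLogger):
    def log_success_event(self, kwargs, response_obj, start_time, end_time):
        print("sync success:", kwargs.get("model"))

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        print("async success:", kwargs.get("model"))

litellm.callbacks = [MyHandler()]
response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hi 👋 - testing a custom logger"}],
)
```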
@ -19,3 +19,63 @@ model_list:
|
||||||
model_info:
|
model_info:
|
||||||
description: this is a test openai model
|
description: this is a test openai model
|
||||||
model_name: test_openai_models
|
model_name: test_openai_models
|
||||||
|
- litellm_params:
|
||||||
|
model: gpt-3.5-turbo
|
||||||
|
model_info:
|
||||||
|
description: this is a test openai model
|
||||||
|
id: 56f1bd94-3b54-4b67-9ea2-7c70e9a3a709
|
||||||
|
model_name: test_openai_models
|
||||||
|
- litellm_params:
|
||||||
|
model: gpt-3.5-turbo
|
||||||
|
model_info:
|
||||||
|
description: this is a test openai model
|
||||||
|
id: 4d1ee26c-abca-450c-8744-8e87fd6755e9
|
||||||
|
model_name: test_openai_models
|
||||||
|
- litellm_params:
|
||||||
|
model: gpt-3.5-turbo
|
||||||
|
model_info:
|
||||||
|
description: this is a test openai model
|
||||||
|
id: 00e19c0f-b63d-42bb-88e9-016fb0c60764
|
||||||
|
model_name: test_openai_models
|
||||||
|
- litellm_params:
|
||||||
|
model: gpt-3.5-turbo
|
||||||
|
model_info:
|
||||||
|
description: this is a test openai model
|
||||||
|
id: 79fc75bf-8e1b-47d5-8d24-9365a854af03
|
||||||
|
model_name: test_openai_models
|
||||||
|
- litellm_params:
|
||||||
|
api_base: os.environ/AZURE_API_BASE
|
||||||
|
api_key: os.environ/AZURE_API_KEY
|
||||||
|
api_version: 2023-07-01-preview
|
||||||
|
model: azure/azure-embedding-model
|
||||||
|
model_name: azure-embedding-model
|
||||||
|
- litellm_params:
|
||||||
|
model: gpt-3.5-turbo
|
||||||
|
model_info:
|
||||||
|
description: this is a test openai model
|
||||||
|
id: 55848c55-4162-40f9-a6e2-9a722b9ef404
|
||||||
|
model_name: test_openai_models
|
||||||
|
- litellm_params:
|
||||||
|
model: gpt-3.5-turbo
|
||||||
|
model_info:
|
||||||
|
description: this is a test openai model
|
||||||
|
id: 34339b1e-e030-4bcc-a531-c48559f10ce4
|
||||||
|
model_name: test_openai_models
|
||||||
|
- litellm_params:
|
||||||
|
model: gpt-3.5-turbo
|
||||||
|
model_info:
|
||||||
|
description: this is a test openai model
|
||||||
|
id: f6f74e14-ac64-4403-9365-319e584dcdc5
|
||||||
|
model_name: test_openai_models
|
||||||
|
- litellm_params:
|
||||||
|
model: gpt-3.5-turbo
|
||||||
|
model_info:
|
||||||
|
description: this is a test openai model
|
||||||
|
id: 9b1ef341-322c-410a-8992-903987fef439
|
||||||
|
model_name: test_openai_models
|
||||||
|
- model_name: amazon-embeddings
|
||||||
|
litellm_params:
|
||||||
|
model: "bedrock/amazon.titan-embed-text-v1"
|
||||||
|
- model_name: "GPT-J 6B - Sagemaker Text Embedding (Internal)"
|
||||||
|
litellm_params:
|
||||||
|
model: "sagemaker/berri-benchmarking-gpt-j-6b-fp16"
|
631
litellm/tests/test_custom_callback_input.py
Normal file
631
litellm/tests/test_custom_callback_input.py
Normal file
|
@ -0,0 +1,631 @@
|
||||||
|
### What this tests ####
|
||||||
|
## This test asserts the type of data passed into each method of the custom callback handler
|
||||||
|
import sys, os, time, inspect, asyncio, traceback
|
||||||
|
from datetime import datetime
|
||||||
|
import pytest
|
||||||
|
sys.path.insert(0, os.path.abspath('../..'))
|
||||||
|
from typing import Optional, Literal, List, Union
|
||||||
|
from litellm import completion, embedding, Cache
|
||||||
|
import litellm
|
||||||
|
from litellm.integrations.custom_logger import CustomLogger
|
||||||
|
|
||||||
|
# Test Scenarios (test across completion, streaming, embedding)
|
||||||
|
## 1: Pre-API-Call
|
||||||
|
## 2: Post-API-Call
|
||||||
|
## 3: On LiteLLM Call success
|
||||||
|
## 4: On LiteLLM Call failure
|
||||||
|
## 5. Caching
|
||||||
|
|
||||||
|
# Test models
|
||||||
|
## 1. OpenAI
|
||||||
|
## 2. Azure OpenAI
|
||||||
|
## 3. Non-OpenAI/Azure - e.g. Bedrock
|
||||||
|
|
||||||
|
# Test interfaces
|
||||||
|
## 1. litellm.completion() + litellm.embeddings()
|
||||||
|
## refer to test_custom_callback_input_router.py for the router + proxy tests
|
||||||
|
|
||||||
|
class CompletionCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/observability/custom_callback#callback-class
|
||||||
|
"""
|
||||||
|
The set of expected inputs to a custom handler for a
|
||||||
|
"""
|
||||||
|
# Class variables or attributes
|
||||||
|
def __init__(self):
|
||||||
|
self.errors = []
|
||||||
|
self.states: Optional[List[Literal["sync_pre_api_call", "async_pre_api_call", "post_api_call", "sync_stream", "async_stream", "sync_success", "async_success", "sync_failure", "async_failure"]]] = []
|
||||||
|
|
||||||
|
def log_pre_api_call(self, model, messages, kwargs):
|
||||||
|
try:
|
||||||
|
self.states.append("sync_pre_api_call")
|
||||||
|
## MODEL
|
||||||
|
assert isinstance(model, str)
|
||||||
|
## MESSAGES
|
||||||
|
assert isinstance(messages, list)
|
||||||
|
## KWARGS
|
||||||
|
assert isinstance(kwargs['model'], str)
|
||||||
|
assert isinstance(kwargs['messages'], list)
|
||||||
|
assert isinstance(kwargs['optional_params'], dict)
|
||||||
|
assert isinstance(kwargs['litellm_params'], dict)
|
||||||
|
assert isinstance(kwargs['start_time'], (datetime, type(None)))
|
||||||
|
assert isinstance(kwargs['stream'], bool)
|
||||||
|
assert isinstance(kwargs['user'], (str, type(None)))
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Assertion Error: {traceback.format_exc()}")
|
||||||
|
self.errors.append(traceback.format_exc())
|
||||||
|
|
||||||
|
def log_post_api_call(self, kwargs, response_obj, start_time, end_time):
|
||||||
|
try:
|
||||||
|
self.states.append("post_api_call")
|
||||||
|
## START TIME
|
||||||
|
assert isinstance(start_time, datetime)
|
||||||
|
## END TIME
|
||||||
|
assert end_time == None
|
||||||
|
## RESPONSE OBJECT
|
||||||
|
assert response_obj == None
|
||||||
|
## KWARGS
|
||||||
|
assert isinstance(kwargs['model'], str)
|
||||||
|
assert isinstance(kwargs['messages'], list)
|
||||||
|
assert isinstance(kwargs['optional_params'], dict)
|
||||||
|
assert isinstance(kwargs['litellm_params'], dict)
|
||||||
|
assert isinstance(kwargs['start_time'], (datetime, type(None)))
|
||||||
|
assert isinstance(kwargs['stream'], bool)
|
||||||
|
assert isinstance(kwargs['user'], (str, type(None)))
|
||||||
|
assert isinstance(kwargs['input'], (list, dict, str))
|
||||||
|
assert isinstance(kwargs['api_key'], (str, type(None)))
|
||||||
|
assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or inspect.iscoroutine(kwargs['original_response']) or inspect.isasyncgen(kwargs['original_response'])
|
||||||
|
assert isinstance(kwargs['additional_args'], (dict, type(None)))
|
||||||
|
assert isinstance(kwargs['log_event_type'], str)
|
||||||
|
except:
|
||||||
|
print(f"Assertion Error: {traceback.format_exc()}")
|
||||||
|
self.errors.append(traceback.format_exc())
|
||||||
|
|
||||||
|
async def async_log_stream_event(self, kwargs, response_obj, start_time, end_time):
|
||||||
|
try:
|
||||||
|
self.states.append("async_stream")
|
||||||
|
## START TIME
|
||||||
|
assert isinstance(start_time, datetime)
|
||||||
|
## END TIME
|
||||||
|
assert isinstance(end_time, datetime)
|
||||||
|
## RESPONSE OBJECT
|
||||||
|
assert isinstance(response_obj, litellm.ModelResponse)
|
||||||
|
## KWARGS
|
||||||
|
assert isinstance(kwargs['model'], str)
|
||||||
|
assert isinstance(kwargs['messages'], list) and isinstance(kwargs['messages'][0], dict)
|
||||||
|
assert isinstance(kwargs['optional_params'], dict)
|
||||||
|
assert isinstance(kwargs['litellm_params'], dict)
|
||||||
|
assert isinstance(kwargs['start_time'], (datetime, type(None)))
|
||||||
|
assert isinstance(kwargs['stream'], bool)
|
||||||
|
assert isinstance(kwargs['user'], (str, type(None)))
|
||||||
|
assert (isinstance(kwargs['input'], list) and isinstance(kwargs['input'][0], dict)) or isinstance(kwargs['input'], (dict, str))
|
||||||
|
assert isinstance(kwargs['api_key'], (str, type(None)))
|
||||||
|
assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or inspect.isasyncgen(kwargs['original_response']) or inspect.iscoroutine(kwargs['original_response'])
|
||||||
|
assert isinstance(kwargs['additional_args'], (dict, type(None)))
|
||||||
|
assert isinstance(kwargs['log_event_type'], str)
|
||||||
|
except:
|
||||||
|
print(f"Assertion Error: {traceback.format_exc()}")
|
||||||
|
self.errors.append(traceback.format_exc())
|
||||||
|
|
||||||
|
def log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||||
|
try:
|
||||||
|
self.states.append("sync_success")
|
||||||
|
## START TIME
|
||||||
|
assert isinstance(start_time, datetime)
|
||||||
|
## END TIME
|
||||||
|
assert isinstance(end_time, datetime)
|
||||||
|
## RESPONSE OBJECT
|
||||||
|
assert isinstance(response_obj, litellm.ModelResponse)
|
||||||
|
## KWARGS
|
||||||
|
assert isinstance(kwargs['model'], str)
|
||||||
|
assert isinstance(kwargs['messages'], list) and isinstance(kwargs['messages'][0], dict)
|
||||||
|
assert isinstance(kwargs['optional_params'], dict)
|
||||||
|
assert isinstance(kwargs['litellm_params'], dict)
|
||||||
|
assert isinstance(kwargs['start_time'], (datetime, type(None)))
|
||||||
|
assert isinstance(kwargs['stream'], bool)
|
||||||
|
assert isinstance(kwargs['user'], (str, type(None)))
|
||||||
|
assert (isinstance(kwargs['input'], list) and isinstance(kwargs['input'][0], dict)) or isinstance(kwargs['input'], (dict, str))
|
||||||
|
assert isinstance(kwargs['api_key'], (str, type(None)))
|
||||||
|
assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper))
|
||||||
|
assert isinstance(kwargs['additional_args'], (dict, type(None)))
|
||||||
|
assert isinstance(kwargs['log_event_type'], str)
|
||||||
|
except:
|
||||||
|
print(f"Assertion Error: {traceback.format_exc()}")
|
||||||
|
self.errors.append(traceback.format_exc())
|
||||||
|
|
||||||
|
def log_failure_event(self, kwargs, response_obj, start_time, end_time):
|
||||||
|
try:
|
||||||
|
self.states.append("sync_failure")
|
||||||
|
## START TIME
|
||||||
|
assert isinstance(start_time, datetime)
|
||||||
|
## END TIME
|
||||||
|
assert isinstance(end_time, datetime)
|
||||||
|
## RESPONSE OBJECT
|
||||||
|
assert response_obj == None
|
||||||
|
## KWARGS
|
||||||
|
assert isinstance(kwargs['model'], str)
|
||||||
|
assert isinstance(kwargs['messages'], list) and isinstance(kwargs['messages'][0], dict)
|
||||||
|
assert isinstance(kwargs['optional_params'], dict)
|
||||||
|
assert isinstance(kwargs['litellm_params'], dict)
|
||||||
|
assert isinstance(kwargs['start_time'], (datetime, type(None)))
|
||||||
|
assert isinstance(kwargs['stream'], bool)
|
||||||
|
assert isinstance(kwargs['user'], (str, type(None)))
|
||||||
|
assert (isinstance(kwargs['input'], list) and isinstance(kwargs['input'][0], dict)) or isinstance(kwargs['input'], (dict, str))
|
||||||
|
assert isinstance(kwargs['api_key'], (str, type(None)))
|
||||||
|
assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or kwargs["original_response"] == None
|
||||||
|
assert isinstance(kwargs['additional_args'], (dict, type(None)))
|
||||||
|
assert isinstance(kwargs['log_event_type'], str)
|
||||||
|
except:
|
||||||
|
print(f"Assertion Error: {traceback.format_exc()}")
|
||||||
|
self.errors.append(traceback.format_exc())
|
||||||
|
|
||||||
|
async def async_log_pre_api_call(self, model, messages, kwargs):
|
||||||
|
try:
|
||||||
|
self.states.append("async_pre_api_call")
|
||||||
|
## MODEL
|
||||||
|
assert isinstance(model, str)
|
||||||
|
## MESSAGES
|
||||||
|
assert isinstance(messages, list) and isinstance(messages[0], dict)
|
||||||
|
## KWARGS
|
||||||
|
assert isinstance(kwargs['model'], str)
|
||||||
|
assert isinstance(kwargs['messages'], list) and isinstance(kwargs['messages'][0], dict)
|
||||||
|
assert isinstance(kwargs['optional_params'], dict)
|
||||||
|
assert isinstance(kwargs['litellm_params'], dict)
|
||||||
|
assert isinstance(kwargs['start_time'], (datetime, type(None)))
|
||||||
|
assert isinstance(kwargs['stream'], bool)
|
||||||
|
assert isinstance(kwargs['user'], (str, type(None)))
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Assertion Error: {traceback.format_exc()}")
|
||||||
|
self.errors.append(traceback.format_exc())
|
||||||
|
|
||||||
|
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||||
|
try:
|
||||||
|
self.states.append("async_success")
|
||||||
|
## START TIME
|
||||||
|
assert isinstance(start_time, datetime)
|
||||||
|
## END TIME
|
||||||
|
assert isinstance(end_time, datetime)
|
||||||
|
## RESPONSE OBJECT
|
||||||
|
assert isinstance(response_obj, (litellm.ModelResponse, litellm.EmbeddingResponse))
|
||||||
|
## KWARGS
|
||||||
|
assert isinstance(kwargs['model'], str)
|
||||||
|
assert isinstance(kwargs['messages'], list)
|
||||||
|
assert isinstance(kwargs['optional_params'], dict)
|
||||||
|
assert isinstance(kwargs['litellm_params'], dict)
|
||||||
|
assert isinstance(kwargs['start_time'], (datetime, type(None)))
|
||||||
|
assert isinstance(kwargs['stream'], bool)
|
||||||
|
assert isinstance(kwargs['user'], (str, type(None)))
|
||||||
|
assert isinstance(kwargs['input'], (list, dict, str))
|
||||||
|
assert isinstance(kwargs['api_key'], (str, type(None)))
|
||||||
|
assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or inspect.isasyncgen(kwargs['original_response']) or inspect.iscoroutine(kwargs['original_response'])
|
||||||
|
assert isinstance(kwargs['additional_args'], (dict, type(None)))
|
||||||
|
assert isinstance(kwargs['log_event_type'], str)
|
||||||
|
assert kwargs["cache_hit"] is None or isinstance(kwargs["cache_hit"], bool)
|
||||||
|
except:
|
||||||
|
print(f"Assertion Error: {traceback.format_exc()}")
|
||||||
|
self.errors.append(traceback.format_exc())
|
||||||
|
|
    async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
        try:
            self.states.append("async_failure")
            ## START TIME
            assert isinstance(start_time, datetime)
            ## END TIME
            assert isinstance(end_time, datetime)
            ## RESPONSE OBJECT
            assert response_obj == None
            ## KWARGS
            assert isinstance(kwargs['model'], str)
            assert isinstance(kwargs['messages'], list)
            assert isinstance(kwargs['optional_params'], dict)
            assert isinstance(kwargs['litellm_params'], dict)
            assert isinstance(kwargs['start_time'], (datetime, type(None)))
            assert isinstance(kwargs['stream'], bool)
            assert isinstance(kwargs['user'], (str, type(None)))
            assert isinstance(kwargs['input'], (list, str, dict))
            assert isinstance(kwargs['api_key'], (str, type(None)))
            assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or inspect.isasyncgen(kwargs['original_response']) or kwargs['original_response'] == None
            assert isinstance(kwargs['additional_args'], (dict, type(None)))
            assert isinstance(kwargs['log_event_type'], str)
        except:
            print(f"Assertion Error: {traceback.format_exc()}")
            self.errors.append(traceback.format_exc())

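# Added for clarity: every test below follows the same wiring pattern — register a
# CompletionCustomHandler on `litellm.callbacks`, run a completion/embedding call,
# then assert that `handler.errors` stayed empty. A minimal sketch of that pattern
# (the function name here is illustrative and not part of the original suite):
def _example_callback_wiring():
    handler = CompletionCustomHandler()
    litellm.callbacks = [handler]                 # register the handler globally
    litellm.completion(model="gpt-3.5-turbo",
                       messages=[{"role": "user", "content": "hi"}])
    time.sleep(1)                                 # success callbacks may fire asynchronously
    assert len(handler.errors) == 0               # no type assertions failed inside the callbacks
    litellm.callbacks = []                        # reset global state between tests
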
# COMPLETION
## Test OpenAI + sync
def test_chat_openai_stream():
    try:
        customHandler = CompletionCustomHandler()
        litellm.callbacks = [customHandler]
        response = litellm.completion(model="gpt-3.5-turbo",
                                      messages=[{"role": "user", "content": "Hi 👋 - i'm sync openai"}])
        ## test streaming
        response = litellm.completion(model="gpt-3.5-turbo",
                                      messages=[{"role": "user", "content": "Hi 👋 - i'm openai"}],
                                      stream=True)
        for chunk in response:
            continue
        ## test failure callback
        try:
            response = litellm.completion(model="gpt-3.5-turbo",
                                          messages=[{"role": "user", "content": "Hi 👋 - i'm openai"}],
                                          api_key="my-bad-key",
                                          stream=True)
            for chunk in response:
                continue
        except:
            pass
        time.sleep(1)
        print(f"customHandler.errors: {customHandler.errors}")
        assert len(customHandler.errors) == 0
        litellm.callbacks = []
    except Exception as e:
        pytest.fail(f"An exception occurred: {str(e)}")

# test_chat_openai_stream()

## Test OpenAI + Async
@pytest.mark.asyncio
async def test_async_chat_openai_stream():
    try:
        customHandler = CompletionCustomHandler()
        litellm.callbacks = [customHandler]
        response = await litellm.acompletion(model="gpt-3.5-turbo",
                                             messages=[{"role": "user", "content": "Hi 👋 - i'm openai"}])
        ## test streaming
        response = await litellm.acompletion(model="gpt-3.5-turbo",
                                             messages=[{"role": "user", "content": "Hi 👋 - i'm openai"}],
                                             stream=True)
        async for chunk in response:
            continue
        ## test failure callback
        try:
            response = await litellm.acompletion(model="gpt-3.5-turbo",
                                                 messages=[{"role": "user", "content": "Hi 👋 - i'm openai"}],
                                                 api_key="my-bad-key",
                                                 stream=True)
            async for chunk in response:
                continue
        except:
            pass
        time.sleep(1)
        print(f"customHandler.errors: {customHandler.errors}")
        assert len(customHandler.errors) == 0
        litellm.callbacks = []
    except Exception as e:
        pytest.fail(f"An exception occurred: {str(e)}")

# asyncio.run(test_async_chat_openai_stream())

## Test Azure + sync
def test_chat_azure_stream():
    try:
        customHandler = CompletionCustomHandler()
        litellm.callbacks = [customHandler]
        response = litellm.completion(model="azure/chatgpt-v-2",
                                      messages=[{"role": "user", "content": "Hi 👋 - i'm sync azure"}])
        # test streaming
        response = litellm.completion(model="azure/chatgpt-v-2",
                                      messages=[{"role": "user", "content": "Hi 👋 - i'm sync azure"}],
                                      stream=True)
        for chunk in response:
            continue
        # test failure callback
        try:
            response = litellm.completion(model="azure/chatgpt-v-2",
                                          messages=[{"role": "user", "content": "Hi 👋 - i'm sync azure"}],
                                          api_key="my-bad-key",
                                          stream=True)
            for chunk in response:
                continue
        except:
            pass
        time.sleep(1)
        print(f"customHandler.errors: {customHandler.errors}")
        assert len(customHandler.errors) == 0
        litellm.callbacks = []
    except Exception as e:
        pytest.fail(f"An exception occurred: {str(e)}")

# test_chat_azure_stream()

## Test Azure + Async
@pytest.mark.asyncio
async def test_async_chat_azure_stream():
    try:
        customHandler = CompletionCustomHandler()
        litellm.callbacks = [customHandler]
        response = await litellm.acompletion(model="azure/chatgpt-v-2",
                                             messages=[{"role": "user", "content": "Hi 👋 - i'm async azure"}])
        ## test streaming
        response = await litellm.acompletion(model="azure/chatgpt-v-2",
                                             messages=[{"role": "user", "content": "Hi 👋 - i'm async azure"}],
                                             stream=True)
        async for chunk in response:
            continue
        ## test failure callback
        try:
            response = await litellm.acompletion(model="azure/chatgpt-v-2",
                                                 messages=[{"role": "user", "content": "Hi 👋 - i'm async azure"}],
                                                 api_key="my-bad-key",
                                                 stream=True)
            async for chunk in response:
                continue
        except:
            pass
        await asyncio.sleep(1)
        print(f"customHandler.errors: {customHandler.errors}")
        assert len(customHandler.errors) == 0
        litellm.callbacks = []
    except Exception as e:
        pytest.fail(f"An exception occurred: {str(e)}")

# asyncio.run(test_async_chat_azure_stream())

## Test Bedrock + sync
def test_chat_bedrock_stream():
    try:
        customHandler = CompletionCustomHandler()
        litellm.callbacks = [customHandler]
        response = litellm.completion(model="bedrock/anthropic.claude-v1",
                                      messages=[{"role": "user", "content": "Hi 👋 - i'm sync bedrock"}])
        # test streaming
        response = litellm.completion(model="bedrock/anthropic.claude-v1",
                                      messages=[{"role": "user", "content": "Hi 👋 - i'm sync bedrock"}],
                                      stream=True)
        for chunk in response:
            continue
        # test failure callback
        try:
            response = litellm.completion(model="bedrock/anthropic.claude-v1",
                                          messages=[{"role": "user", "content": "Hi 👋 - i'm sync bedrock"}],
                                          aws_region_name="my-bad-region",
                                          stream=True)
            for chunk in response:
                continue
        except:
            pass
        time.sleep(1)
        print(f"customHandler.errors: {customHandler.errors}")
        assert len(customHandler.errors) == 0
        litellm.callbacks = []
    except Exception as e:
        pytest.fail(f"An exception occurred: {str(e)}")

# test_chat_bedrock_stream()

## Test Bedrock + Async
@pytest.mark.asyncio
async def test_async_chat_bedrock_stream():
    try:
        customHandler = CompletionCustomHandler()
        litellm.callbacks = [customHandler]
        response = await litellm.acompletion(model="bedrock/anthropic.claude-v1",
                                             messages=[{"role": "user", "content": "Hi 👋 - i'm async bedrock"}])
        # test streaming
        response = await litellm.acompletion(model="bedrock/anthropic.claude-v1",
                                             messages=[{"role": "user", "content": "Hi 👋 - i'm async bedrock"}],
                                             stream=True)
        print(f"response: {response}")
        async for chunk in response:
            print(f"chunk: {chunk}")
            continue
        ## test failure callback
        try:
            response = await litellm.acompletion(model="bedrock/anthropic.claude-v1",
                                                 messages=[{"role": "user", "content": "Hi 👋 - i'm async bedrock"}],
                                                 aws_region_name="my-bad-key",
                                                 stream=True)
            async for chunk in response:
                continue
        except:
            pass
        time.sleep(1)
        print(f"customHandler.errors: {customHandler.errors}")
        assert len(customHandler.errors) == 0
        litellm.callbacks = []
    except Exception as e:
        pytest.fail(f"An exception occurred: {str(e)}")

# asyncio.run(test_async_chat_bedrock_stream())

# EMBEDDING
## Test OpenAI + Async
@pytest.mark.asyncio
async def test_async_embedding_openai():
    try:
        customHandler_success = CompletionCustomHandler()
        customHandler_failure = CompletionCustomHandler()
        litellm.callbacks = [customHandler_success]
        response = await litellm.aembedding(model="azure/azure-embedding-model",
                                            input=["good morning from litellm"])
        await asyncio.sleep(1)
        print(f"customHandler_success.errors: {customHandler_success.errors}")
        print(f"customHandler_success.states: {customHandler_success.states}")
        assert len(customHandler_success.errors) == 0
        assert len(customHandler_success.states) == 3  # pre, post, success
        # test failure callback
        litellm.callbacks = [customHandler_failure]
        try:
            response = await litellm.aembedding(model="text-embedding-ada-002",
                                                input=["good morning from litellm"],
                                                api_key="my-bad-key")
        except:
            pass
        await asyncio.sleep(1)
        print(f"customHandler_failure.errors: {customHandler_failure.errors}")
        print(f"customHandler_failure.states: {customHandler_failure.states}")
        assert len(customHandler_failure.errors) == 0
        assert len(customHandler_failure.states) == 3  # pre, post, failure
    except Exception as e:
        pytest.fail(f"An exception occurred: {str(e)}")

# asyncio.run(test_async_embedding_openai())

## Test Azure + Async
@pytest.mark.asyncio
async def test_async_embedding_azure():
    try:
        customHandler_success = CompletionCustomHandler()
        customHandler_failure = CompletionCustomHandler()
        litellm.callbacks = [customHandler_success]
        response = await litellm.aembedding(model="azure/azure-embedding-model",
                                            input=["good morning from litellm"])
        await asyncio.sleep(1)
        print(f"customHandler_success.errors: {customHandler_success.errors}")
        print(f"customHandler_success.states: {customHandler_success.states}")
        assert len(customHandler_success.errors) == 0
        assert len(customHandler_success.states) == 3  # pre, post, success
        # test failure callback
        litellm.callbacks = [customHandler_failure]
        try:
            response = await litellm.aembedding(model="azure/azure-embedding-model",
                                                input=["good morning from litellm"],
                                                api_key="my-bad-key")
        except:
            pass
        await asyncio.sleep(1)
        print(f"customHandler_failure.errors: {customHandler_failure.errors}")
        print(f"customHandler_failure.states: {customHandler_failure.states}")
        assert len(customHandler_failure.errors) == 0
        assert len(customHandler_failure.states) == 3  # pre, post, failure
    except Exception as e:
        pytest.fail(f"An exception occurred: {str(e)}")

# asyncio.run(test_async_embedding_azure())

## Test Bedrock + Async
@pytest.mark.asyncio
async def test_async_embedding_bedrock():
    try:
        customHandler_success = CompletionCustomHandler()
        customHandler_failure = CompletionCustomHandler()
        litellm.callbacks = [customHandler_success]
        litellm.set_verbose = True
        response = await litellm.aembedding(model="bedrock/cohere.embed-multilingual-v3",
                                            input=["good morning from litellm"],
                                            aws_region_name="os.environ/AWS_REGION_NAME_2")
        await asyncio.sleep(1)
        print(f"customHandler_success.errors: {customHandler_success.errors}")
        print(f"customHandler_success.states: {customHandler_success.states}")
        assert len(customHandler_success.errors) == 0
        assert len(customHandler_success.states) == 3  # pre, post, success
        # test failure callback
        litellm.callbacks = [customHandler_failure]
        try:
            response = await litellm.aembedding(model="bedrock/cohere.embed-multilingual-v3",
                                                input=["good morning from litellm"],
                                                aws_region_name="my-bad-region")
        except:
            pass
        await asyncio.sleep(1)
        print(f"customHandler_failure.errors: {customHandler_failure.errors}")
        print(f"customHandler_failure.states: {customHandler_failure.states}")
        assert len(customHandler_failure.errors) == 0
        assert len(customHandler_failure.states) == 3  # pre, post, failure
    except Exception as e:
        pytest.fail(f"An exception occurred: {str(e)}")

# asyncio.run(test_async_embedding_bedrock())

# CACHING
## Test Azure - completion, embedding
@pytest.mark.asyncio
async def test_async_completion_azure_caching():
    customHandler_caching = CompletionCustomHandler()
    litellm.cache = Cache(type="redis", host=os.environ['REDIS_HOST'], port=os.environ['REDIS_PORT'], password=os.environ['REDIS_PASSWORD'])
    litellm.callbacks = [customHandler_caching]
    unique_time = time.time()
    response1 = await litellm.acompletion(model="azure/chatgpt-v-2",
                                          messages=[{"role": "user", "content": f"Hi 👋 - i'm async azure {unique_time}"}],
                                          caching=True)
    await asyncio.sleep(1)
    print(f"customHandler_caching.states pre-cache hit: {customHandler_caching.states}")
    response2 = await litellm.acompletion(model="azure/chatgpt-v-2",
                                          messages=[{"role": "user", "content": f"Hi 👋 - i'm async azure {unique_time}"}],
                                          caching=True)
    await asyncio.sleep(1)  # success callbacks are done in parallel
    print(f"customHandler_caching.states post-cache hit: {customHandler_caching.states}")
    assert len(customHandler_caching.errors) == 0
    assert len(customHandler_caching.states) == 4  # pre, post, success, success

@pytest.mark.asyncio
async def test_async_embedding_azure_caching():
    print("Testing custom callback input - Azure Caching")
    customHandler_caching = CompletionCustomHandler()
    litellm.cache = Cache(type="redis", host=os.environ['REDIS_HOST'], port=os.environ['REDIS_PORT'], password=os.environ['REDIS_PASSWORD'])
    litellm.callbacks = [customHandler_caching]
    unique_time = time.time()
    response1 = await litellm.aembedding(model="azure/azure-embedding-model",
                                         input=[f"good morning from litellm1 {unique_time}"],
                                         caching=True)
    await asyncio.sleep(1)  # set cache is async for aembedding()
    response2 = await litellm.aembedding(model="azure/azure-embedding-model",
                                         input=[f"good morning from litellm1 {unique_time}"],
                                         caching=True)
    await asyncio.sleep(1)  # success callbacks are done in parallel
    print(customHandler_caching.states)
    assert len(customHandler_caching.errors) == 0
    assert len(customHandler_caching.states) == 4  # pre, post, success, success

# asyncio.run(
#     test_async_embedding_azure_caching()
# )
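# Note (added for clarity): the "pre, post, success, success" count of 4 in the two
# caching tests above implies that the second, cache-hitting call triggers only the
# success callback, with no additional pre/post API-call events. This reading is
# inferred from the tests' own state-count comments, not from litellm internals.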
488  litellm/tests/test_custom_callback_router.py  (new file)
@@ -0,0 +1,488 @@
### What this tests ####
## This test asserts the type of data passed into each method of the custom callback handler
import sys, os, time, inspect, asyncio, traceback
from datetime import datetime
import pytest
sys.path.insert(0, os.path.abspath('../..'))
from typing import Optional, Literal, List
from litellm import Router, Cache
import litellm
from litellm.integrations.custom_logger import CustomLogger

# Test Scenarios (test across completion, streaming, embedding)
## 1: Pre-API-Call
## 2: Post-API-Call
## 3: On LiteLLM Call success
## 4: On LiteLLM Call failure
## fallbacks
## retries

# Test cases
## 1. Simple Azure OpenAI acompletion + streaming call
## 2. Simple Azure OpenAI aembedding call
## 3. Azure OpenAI acompletion + streaming call with retries
## 4. Azure OpenAI aembedding call with retries
## 5. Azure OpenAI acompletion + streaming call with fallbacks
## 6. Azure OpenAI aembedding call with fallbacks

# Test interfaces
## 1. router.completion() + router.embeddings()
## 2. proxy.completions + proxy.embeddings

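# Added for clarity: the router tests in this file all share the same setup — build
# a Router from a model_list, register a handler on litellm.callbacks, then call
# router.acompletion() / router.aembedding(). A minimal sketch, assuming the
# CompletionCustomHandler defined just below (the function name and the "hi" prompt
# are illustrative; the env vars mirror the ones used further down):
def _example_router_callback_wiring():
    handler = CompletionCustomHandler()
    litellm.callbacks = [handler]                      # handler also sees router metadata
    router = Router(model_list=[{
        "model_name": "gpt-3.5-turbo",
        "litellm_params": {
            "model": "azure/chatgpt-v-2",
            "api_key": os.getenv("AZURE_API_KEY"),
            "api_version": os.getenv("AZURE_API_VERSION"),
            "api_base": os.getenv("AZURE_API_BASE"),
        },
    }])
    asyncio.run(router.acompletion(model="gpt-3.5-turbo",
                                   messages=[{"role": "user", "content": "hi"}]))
    assert len(handler.errors) == 0                    # router-specific kwargs were type-checked
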
class CompletionCustomHandler(CustomLogger):  # https://docs.litellm.ai/docs/observability/custom_callback#callback-class
    """
    The set of expected inputs to a custom handler for a completion call.
    """
    # Class variables or attributes
    def __init__(self):
        self.errors = []
        self.states: Optional[List[Literal["sync_pre_api_call", "async_pre_api_call", "post_api_call", "sync_stream", "async_stream", "sync_success", "async_success", "sync_failure", "async_failure"]]] = []

    def log_pre_api_call(self, model, messages, kwargs):
        try:
            print(f'received kwargs in pre-input: {kwargs}')
            self.states.append("sync_pre_api_call")
            ## MODEL
            assert isinstance(model, str)
            ## MESSAGES
            assert isinstance(messages, list)
            ## KWARGS
            assert isinstance(kwargs['model'], str)
            assert isinstance(kwargs['messages'], list)
            assert isinstance(kwargs['optional_params'], dict)
            assert isinstance(kwargs['litellm_params'], dict)
            assert isinstance(kwargs['start_time'], (datetime, type(None)))
            assert isinstance(kwargs['stream'], bool)
            assert isinstance(kwargs['user'], (str, type(None)))
            ### ROUTER-SPECIFIC KWARGS
            assert isinstance(kwargs["litellm_params"]["metadata"], dict)
            assert isinstance(kwargs["litellm_params"]["metadata"]["model_group"], str)
            assert isinstance(kwargs["litellm_params"]["metadata"]["deployment"], str)
            assert isinstance(kwargs["litellm_params"]["model_info"], dict)
            assert isinstance(kwargs["litellm_params"]["model_info"]["id"], str)
            assert isinstance(kwargs["litellm_params"]["proxy_server_request"], (str, type(None)))
            assert isinstance(kwargs["litellm_params"]["preset_cache_key"], (str, type(None)))
            assert isinstance(kwargs["litellm_params"]["stream_response"], dict)
        except Exception as e:
            print(f"Assertion Error: {traceback.format_exc()}")
            self.errors.append(traceback.format_exc())

    def log_post_api_call(self, kwargs, response_obj, start_time, end_time):
        try:
            self.states.append("post_api_call")
            ## START TIME
            assert isinstance(start_time, datetime)
            ## END TIME
            assert end_time == None
            ## RESPONSE OBJECT
            assert response_obj == None
            ## KWARGS
            assert isinstance(kwargs['model'], str)
            assert isinstance(kwargs['messages'], list)
            assert isinstance(kwargs['optional_params'], dict)
            assert isinstance(kwargs['litellm_params'], dict)
            assert isinstance(kwargs['start_time'], (datetime, type(None)))
            assert isinstance(kwargs['stream'], bool)
            assert isinstance(kwargs['user'], (str, type(None)))
            assert isinstance(kwargs['input'], (list, dict, str))
            assert isinstance(kwargs['api_key'], (str, type(None)))
            assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or inspect.iscoroutine(kwargs['original_response']) or inspect.isasyncgen(kwargs['original_response'])
            assert isinstance(kwargs['additional_args'], (dict, type(None)))
            assert isinstance(kwargs['log_event_type'], str)
            ### ROUTER-SPECIFIC KWARGS
            assert isinstance(kwargs["litellm_params"]["metadata"], dict)
            assert isinstance(kwargs["litellm_params"]["metadata"]["model_group"], str)
            assert isinstance(kwargs["litellm_params"]["metadata"]["deployment"], str)
            assert isinstance(kwargs["litellm_params"]["model_info"], dict)
            assert isinstance(kwargs["litellm_params"]["model_info"]["id"], str)
            assert isinstance(kwargs["litellm_params"]["proxy_server_request"], (str, type(None)))
            assert isinstance(kwargs["litellm_params"]["preset_cache_key"], (str, type(None)))
            assert isinstance(kwargs["litellm_params"]["stream_response"], dict)
        except:
            print(f"Assertion Error: {traceback.format_exc()}")
            self.errors.append(traceback.format_exc())

    async def async_log_stream_event(self, kwargs, response_obj, start_time, end_time):
        try:
            self.states.append("async_stream")
            ## START TIME
            assert isinstance(start_time, datetime)
            ## END TIME
            assert isinstance(end_time, datetime)
            ## RESPONSE OBJECT
            assert isinstance(response_obj, litellm.ModelResponse)
            ## KWARGS
            assert isinstance(kwargs['model'], str)
            assert isinstance(kwargs['messages'], list) and isinstance(kwargs['messages'][0], dict)
            assert isinstance(kwargs['optional_params'], dict)
            assert isinstance(kwargs['litellm_params'], dict)
            assert isinstance(kwargs['start_time'], (datetime, type(None)))
            assert isinstance(kwargs['stream'], bool)
            assert isinstance(kwargs['user'], (str, type(None)))
            assert (isinstance(kwargs['input'], list) and isinstance(kwargs['input'][0], dict)) or isinstance(kwargs['input'], (dict, str))
            assert isinstance(kwargs['api_key'], (str, type(None)))
            assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or inspect.isasyncgen(kwargs['original_response']) or inspect.iscoroutine(kwargs['original_response'])
            assert isinstance(kwargs['additional_args'], (dict, type(None)))
            assert isinstance(kwargs['log_event_type'], str)
        except:
            print(f"Assertion Error: {traceback.format_exc()}")
            self.errors.append(traceback.format_exc())

    def log_success_event(self, kwargs, response_obj, start_time, end_time):
        try:
            self.states.append("sync_success")
            ## START TIME
            assert isinstance(start_time, datetime)
            ## END TIME
            assert isinstance(end_time, datetime)
            ## RESPONSE OBJECT
            assert isinstance(response_obj, litellm.ModelResponse)
            ## KWARGS
            assert isinstance(kwargs['model'], str)
            assert isinstance(kwargs['messages'], list) and isinstance(kwargs['messages'][0], dict)
            assert isinstance(kwargs['optional_params'], dict)
            assert isinstance(kwargs['litellm_params'], dict)
            assert isinstance(kwargs['start_time'], (datetime, type(None)))
            assert isinstance(kwargs['stream'], bool)
            assert isinstance(kwargs['user'], (str, type(None)))
            assert (isinstance(kwargs['input'], list) and isinstance(kwargs['input'][0], dict)) or isinstance(kwargs['input'], (dict, str))
            assert isinstance(kwargs['api_key'], (str, type(None)))
            assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper))
            assert isinstance(kwargs['additional_args'], (dict, type(None)))
            assert isinstance(kwargs['log_event_type'], str)
            assert kwargs["cache_hit"] is None or isinstance(kwargs["cache_hit"], bool)
        except:
            print(f"Assertion Error: {traceback.format_exc()}")
            self.errors.append(traceback.format_exc())

    def log_failure_event(self, kwargs, response_obj, start_time, end_time):
        try:
            self.states.append("sync_failure")
            ## START TIME
            assert isinstance(start_time, datetime)
            ## END TIME
            assert isinstance(end_time, datetime)
            ## RESPONSE OBJECT
            assert response_obj == None
            ## KWARGS
            assert isinstance(kwargs['model'], str)
            assert isinstance(kwargs['messages'], list) and isinstance(kwargs['messages'][0], dict)
            assert isinstance(kwargs['optional_params'], dict)
            assert isinstance(kwargs['litellm_params'], dict)
            assert isinstance(kwargs['start_time'], (datetime, type(None)))
            assert isinstance(kwargs['stream'], bool)
            assert isinstance(kwargs['user'], (str, type(None)))
            assert (isinstance(kwargs['input'], list) and isinstance(kwargs['input'][0], dict)) or isinstance(kwargs['input'], (dict, str))
            assert isinstance(kwargs['api_key'], (str, type(None)))
            assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or kwargs["original_response"] == None
            assert isinstance(kwargs['additional_args'], (dict, type(None)))
            assert isinstance(kwargs['log_event_type'], str)
        except:
            print(f"Assertion Error: {traceback.format_exc()}")
            self.errors.append(traceback.format_exc())

    async def async_log_pre_api_call(self, model, messages, kwargs):
        try:
            """
            No-op.
            Not implemented yet.
            """
            pass
        except Exception as e:
            print(f"Assertion Error: {traceback.format_exc()}")
            self.errors.append(traceback.format_exc())

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        try:
            self.states.append("async_success")
            ## START TIME
            assert isinstance(start_time, datetime)
            ## END TIME
            assert isinstance(end_time, datetime)
            ## RESPONSE OBJECT
            assert isinstance(response_obj, (litellm.ModelResponse, litellm.EmbeddingResponse))
            ## KWARGS
            assert isinstance(kwargs['model'], str)
            assert isinstance(kwargs['messages'], list)
            assert isinstance(kwargs['optional_params'], dict)
            assert isinstance(kwargs['litellm_params'], dict)
            assert isinstance(kwargs['start_time'], (datetime, type(None)))
            assert isinstance(kwargs['stream'], bool)
            assert isinstance(kwargs['user'], (str, type(None)))
            assert isinstance(kwargs['input'], (list, dict, str))
            assert isinstance(kwargs['api_key'], (str, type(None)))
            assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or inspect.isasyncgen(kwargs['original_response']) or inspect.iscoroutine(kwargs['original_response'])
            assert isinstance(kwargs['additional_args'], (dict, type(None)))
            assert isinstance(kwargs['log_event_type'], str)
            assert kwargs["cache_hit"] is None or isinstance(kwargs["cache_hit"], bool)
            ### ROUTER-SPECIFIC KWARGS
            assert isinstance(kwargs["litellm_params"]["metadata"], dict)
            assert isinstance(kwargs["litellm_params"]["metadata"]["model_group"], str)
            assert isinstance(kwargs["litellm_params"]["metadata"]["deployment"], str)
            assert isinstance(kwargs["litellm_params"]["model_info"], dict)
            assert isinstance(kwargs["litellm_params"]["model_info"]["id"], str)
            assert isinstance(kwargs["litellm_params"]["proxy_server_request"], (str, type(None)))
            assert isinstance(kwargs["litellm_params"]["preset_cache_key"], (str, type(None)))
            assert isinstance(kwargs["litellm_params"]["stream_response"], dict)
        except:
            print(f"Assertion Error: {traceback.format_exc()}")
            self.errors.append(traceback.format_exc())

    async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
        try:
            print(f"received original response: {kwargs['original_response']}")
            self.states.append("async_failure")
            ## START TIME
            assert isinstance(start_time, datetime)
            ## END TIME
            assert isinstance(end_time, datetime)
            ## RESPONSE OBJECT
            assert response_obj == None
            ## KWARGS
            assert isinstance(kwargs['model'], str)
            assert isinstance(kwargs['messages'], list)
            assert isinstance(kwargs['optional_params'], dict)
            assert isinstance(kwargs['litellm_params'], dict)
            assert isinstance(kwargs['start_time'], (datetime, type(None)))
            assert isinstance(kwargs['stream'], bool)
            assert isinstance(kwargs['user'], (str, type(None)))
            assert isinstance(kwargs['input'], (list, str, dict))
            assert isinstance(kwargs['api_key'], (str, type(None)))
            assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or inspect.isasyncgen(kwargs['original_response']) or inspect.iscoroutine(kwargs['original_response']) or kwargs['original_response'] == None
            assert isinstance(kwargs['additional_args'], (dict, type(None)))
            assert isinstance(kwargs['log_event_type'], str)
        except:
            print(f"Assertion Error: {traceback.format_exc()}")
            self.errors.append(traceback.format_exc())

# Simple Azure OpenAI call
## COMPLETION
@pytest.mark.asyncio
async def test_async_chat_azure():
    try:
        customHandler_completion_azure_router = CompletionCustomHandler()
        customHandler_streaming_azure_router = CompletionCustomHandler()
        customHandler_failure = CompletionCustomHandler()
        litellm.callbacks = [customHandler_completion_azure_router]
        model_list = [
            {
                "model_name": "gpt-3.5-turbo",  # openai model name
                "litellm_params": {  # params for litellm completion/embedding call
                    "model": "azure/chatgpt-v-2",
                    "api_key": os.getenv("AZURE_API_KEY"),
                    "api_version": os.getenv("AZURE_API_VERSION"),
                    "api_base": os.getenv("AZURE_API_BASE")
                },
                "tpm": 240000,
                "rpm": 1800
            },
        ]
        router = Router(model_list=model_list)  # type: ignore
        response = await router.acompletion(model="gpt-3.5-turbo",
                                            messages=[{"role": "user", "content": "Hi 👋 - i'm openai"}])
        await asyncio.sleep(2)
        assert len(customHandler_completion_azure_router.errors) == 0
        assert len(customHandler_completion_azure_router.states) == 3  # pre, post, success
        # streaming
        litellm.callbacks = [customHandler_streaming_azure_router]
        router2 = Router(model_list=model_list)  # type: ignore
        response = await router2.acompletion(model="gpt-3.5-turbo",
                                             messages=[{"role": "user", "content": "Hi 👋 - i'm openai"}],
                                             stream=True)
        async for chunk in response:
            print(f"async azure router chunk: {chunk}")
            continue
        await asyncio.sleep(1)
        print(f"customHandler.states: {customHandler_streaming_azure_router.states}")
        assert len(customHandler_streaming_azure_router.errors) == 0
        assert len(customHandler_streaming_azure_router.states) >= 4  # pre, post, stream (multiple times), success
        # failure
        model_list = [
            {
                "model_name": "gpt-3.5-turbo",  # openai model name
                "litellm_params": {  # params for litellm completion/embedding call
                    "model": "azure/chatgpt-v-2",
                    "api_key": "my-bad-key",
                    "api_version": os.getenv("AZURE_API_VERSION"),
                    "api_base": os.getenv("AZURE_API_BASE")
                },
                "tpm": 240000,
                "rpm": 1800
            },
        ]
        litellm.callbacks = [customHandler_failure]
        router3 = Router(model_list=model_list)  # type: ignore
        try:
            response = await router3.acompletion(model="gpt-3.5-turbo",
                                                 messages=[{"role": "user", "content": "Hi 👋 - i'm openai"}])
            print(f"response in router3 acompletion: {response}")
        except:
            pass
        await asyncio.sleep(1)
        print(f"customHandler.states: {customHandler_failure.states}")
        assert len(customHandler_failure.errors) == 0
        assert len(customHandler_failure.states) == 3  # pre, post, failure
        assert "async_failure" in customHandler_failure.states
    except Exception as e:
        print(f"Assertion Error: {traceback.format_exc()}")
        pytest.fail(f"An exception occurred - {str(e)}")
# asyncio.run(test_async_chat_azure())

## EMBEDDING
@pytest.mark.asyncio
async def test_async_embedding_azure():
    try:
        customHandler = CompletionCustomHandler()
        customHandler_failure = CompletionCustomHandler()
        litellm.callbacks = [customHandler]
        model_list = [
            {
                "model_name": "azure-embedding-model",  # openai model name
                "litellm_params": {  # params for litellm completion/embedding call
                    "model": "azure/azure-embedding-model",
                    "api_key": os.getenv("AZURE_API_KEY"),
                    "api_version": os.getenv("AZURE_API_VERSION"),
                    "api_base": os.getenv("AZURE_API_BASE")
                },
                "tpm": 240000,
                "rpm": 1800
            },
        ]
        router = Router(model_list=model_list)  # type: ignore
        response = await router.aembedding(model="azure-embedding-model",
                                           input=["hello from litellm!"])
        await asyncio.sleep(2)
        assert len(customHandler.errors) == 0
        assert len(customHandler.states) == 3  # pre, post, success
        # failure
        model_list = [
            {
                "model_name": "azure-embedding-model",  # openai model name
                "litellm_params": {  # params for litellm completion/embedding call
                    "model": "azure/azure-embedding-model",
                    "api_key": "my-bad-key",
                    "api_version": os.getenv("AZURE_API_VERSION"),
                    "api_base": os.getenv("AZURE_API_BASE")
                },
                "tpm": 240000,
                "rpm": 1800
            },
        ]
        litellm.callbacks = [customHandler_failure]
        router3 = Router(model_list=model_list)  # type: ignore
        try:
            response = await router3.aembedding(model="azure-embedding-model",
                                                input=["hello from litellm!"])
            print(f"response in router3 aembedding: {response}")
        except:
            pass
        await asyncio.sleep(1)
        print(f"customHandler.states: {customHandler_failure.states}")
        assert len(customHandler_failure.errors) == 0
        assert len(customHandler_failure.states) == 3  # pre, post, failure
        assert "async_failure" in customHandler_failure.states
    except Exception as e:
        print(f"Assertion Error: {traceback.format_exc()}")
        pytest.fail(f"An exception occurred - {str(e)}")
# asyncio.run(test_async_embedding_azure())

# Azure OpenAI call w/ Fallbacks
## COMPLETION
@pytest.mark.asyncio
async def test_async_chat_azure_with_fallbacks():
    try:
        customHandler_fallbacks = CompletionCustomHandler()
        litellm.callbacks = [customHandler_fallbacks]
        # with fallbacks
        model_list = [
            {
                "model_name": "gpt-3.5-turbo",  # openai model name
                "litellm_params": {  # params for litellm completion/embedding call
                    "model": "azure/chatgpt-v-2",
                    "api_key": "my-bad-key",
                    "api_version": os.getenv("AZURE_API_VERSION"),
                    "api_base": os.getenv("AZURE_API_BASE")
                },
                "tpm": 240000,
                "rpm": 1800
            },
            {
                "model_name": "gpt-3.5-turbo-16k",
                "litellm_params": {
                    "model": "gpt-3.5-turbo-16k",
                },
                "tpm": 240000,
                "rpm": 1800
            }
        ]
        router = Router(model_list=model_list, fallbacks=[{"gpt-3.5-turbo": ["gpt-3.5-turbo-16k"]}])  # type: ignore
        response = await router.acompletion(model="gpt-3.5-turbo",
                                            messages=[{"role": "user", "content": "Hi 👋 - i'm openai"}])
        await asyncio.sleep(2)
        print(f"customHandler_fallbacks.states: {customHandler_fallbacks.states}")
        assert len(customHandler_fallbacks.errors) == 0
        assert len(customHandler_fallbacks.states) == 6  # pre, post, failure, pre, post, success
        litellm.callbacks = []
    except Exception as e:
        print(f"Assertion Error: {traceback.format_exc()}")
        pytest.fail(f"An exception occurred - {str(e)}")
# asyncio.run(test_async_chat_azure_with_fallbacks())

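# Note (added for clarity): the expected state count of 6 above matches the flow
# described by the test's own comment — the bad-key deployment yields pre, post,
# failure, after which the router retries on the "gpt-3.5-turbo-16k" fallback and
# yields pre, post, success.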
# CACHING
## Test Azure - completion, embedding
@pytest.mark.asyncio
async def test_async_completion_azure_caching():
    customHandler_caching = CompletionCustomHandler()
    litellm.cache = Cache(type="redis", host=os.environ['REDIS_HOST'], port=os.environ['REDIS_PORT'], password=os.environ['REDIS_PASSWORD'])
    litellm.callbacks = [customHandler_caching]
    unique_time = time.time()
    model_list = [
        {
            "model_name": "gpt-3.5-turbo",  # openai model name
            "litellm_params": {  # params for litellm completion/embedding call
                "model": "azure/chatgpt-v-2",
                "api_key": os.getenv("AZURE_API_KEY"),
                "api_version": os.getenv("AZURE_API_VERSION"),
                "api_base": os.getenv("AZURE_API_BASE")
            },
            "tpm": 240000,
            "rpm": 1800
        },
        {
            "model_name": "gpt-3.5-turbo-16k",
            "litellm_params": {
                "model": "gpt-3.5-turbo-16k",
            },
            "tpm": 240000,
            "rpm": 1800
        }
    ]
    router = Router(model_list=model_list)  # type: ignore
    response1 = await router.acompletion(model="gpt-3.5-turbo",
                                         messages=[{"role": "user", "content": f"Hi 👋 - i'm async azure {unique_time}"}],
                                         caching=True)
    await asyncio.sleep(1)
    print(f"customHandler_caching.states pre-cache hit: {customHandler_caching.states}")
    response2 = await router.acompletion(model="gpt-3.5-turbo",
                                         messages=[{"role": "user", "content": f"Hi 👋 - i'm async azure {unique_time}"}],
                                         caching=True)
    await asyncio.sleep(1)  # success callbacks are done in parallel
    print(f"customHandler_caching.states post-cache hit: {customHandler_caching.states}")
    assert len(customHandler_caching.errors) == 0
    assert len(customHandler_caching.states) == 4  # pre, post, success, success

@@ -1,5 +1,5 @@
 ### What this tests ####
-import sys, os, time, inspect, asyncio
+import sys, os, time, inspect, asyncio, traceback
 import pytest
 sys.path.insert(0, os.path.abspath('../..'))
 
@@ -7,9 +7,8 @@ from litellm import completion, embedding
 import litellm
 from litellm.integrations.custom_logger import CustomLogger
 
-async_success = False
-complete_streaming_response_in_callback = ""
 class MyCustomHandler(CustomLogger):
+    complete_streaming_response_in_callback = ""
     def __init__(self):
         self.success: bool = False  # type: ignore
         self.failure: bool = False  # type: ignore
@@ -27,9 +26,12 @@ class MyCustomHandler(CustomLogger):
 
         self.stream_collected_response = None  # type: ignore
         self.sync_stream_collected_response = None  # type: ignore
+        self.user = None  # type: ignore
+        self.data_sent_to_api: dict = {}
 
     def log_pre_api_call(self, model, messages, kwargs):
         print(f"Pre-API Call")
+        self.data_sent_to_api = kwargs["additional_args"].get("complete_input_dict", {})
 
     def log_post_api_call(self, kwargs, response_obj, start_time, end_time):
         print(f"Post-API Call")
@@ -50,9 +52,8 @@ class MyCustomHandler(CustomLogger):
 
     async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
         print(f"On Async success")
+        print(f"received kwargs user: {kwargs['user']}")
         self.async_success = True
-        print("Value of async success: ", self.async_success)
-        print("\n kwargs: ", kwargs)
         if kwargs.get("model") == "text-embedding-ada-002":
             self.async_success_embedding = True
             self.async_embedding_kwargs = kwargs
@@ -60,31 +61,32 @@ class MyCustomHandler(CustomLogger):
         if kwargs.get("stream") == True:
             self.stream_collected_response = response_obj
         self.async_completion_kwargs = kwargs
+        self.user = kwargs.get("user", None)
 
     async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
         print(f"On Async Failure")
         self.async_failure = True
-        print("Value of async failure: ", self.async_failure)
-        print("\n kwargs: ", kwargs)
         if kwargs.get("model") == "text-embedding-ada-002":
             self.async_failure_embedding = True
             self.async_embedding_kwargs_fail = kwargs
 
         self.async_completion_kwargs_fail = kwargs
 
-async def async_test_logging_fn(kwargs, completion_obj, start_time, end_time):
-    global async_success, complete_streaming_response_in_callback
-    print(f"ON ASYNC LOGGING")
-    async_success = True
-    print("\nKWARGS", kwargs)
-    complete_streaming_response_in_callback = kwargs.get("complete_streaming_response")
+class TmpFunction:
+    complete_streaming_response_in_callback = ""
+    async_success: bool = False
+    async def async_test_logging_fn(self, kwargs, completion_obj, start_time, end_time):
+        print(f"ON ASYNC LOGGING")
+        self.async_success = True
+        print(f'kwargs.get("complete_streaming_response"): {kwargs.get("complete_streaming_response")}')
+        self.complete_streaming_response_in_callback = kwargs.get("complete_streaming_response")
 
 
 def test_async_chat_openai_stream():
     try:
-        global complete_streaming_response_in_callback
+        tmp_function = TmpFunction()
         # litellm.set_verbose = True
-        litellm.success_callback = [async_test_logging_fn]
+        litellm.success_callback = [tmp_function.async_test_logging_fn]
         complete_streaming_response = ""
         async def call_gpt():
             nonlocal complete_streaming_response
@@ -98,12 +100,16 @@ def test_async_chat_openai_stream():
                 complete_streaming_response += chunk["choices"][0]["delta"]["content"] or ""
             print(complete_streaming_response)
         asyncio.run(call_gpt())
-        assert complete_streaming_response_in_callback["choices"][0]["message"]["content"] == complete_streaming_response
-        assert async_success == True
+        complete_streaming_response = complete_streaming_response.strip("'")
+        response1 = tmp_function.complete_streaming_response_in_callback["choices"][0]["message"]["content"]
+        response2 = complete_streaming_response
+        # assert [ord(c) for c in response1] == [ord(c) for c in response2]
+        assert response1 == response2
+        assert tmp_function.async_success == True
     except Exception as e:
         print(e)
         pytest.fail(f"An error occurred - {str(e)}")
-test_async_chat_openai_stream()
+# test_async_chat_openai_stream()
 
 def test_completion_azure_stream_moderation_failure():
     try:
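The hunks above refactor the streaming success callback from a module-level function mutating globals into a small state-holding class whose bound method is registered on `litellm.success_callback`. A minimal sketch of the resulting pattern, independent of the test above (class and method names here are illustrative, not from the diff):

class CallbackProbe:
    def __init__(self):
        self.async_success = False
        self.captured = None
    async def on_success(self, kwargs, completion_obj, start_time, end_time):
        # state lives on the instance instead of module-level globals
        self.async_success = True
        self.captured = kwargs.get("complete_streaming_response")

probe = CallbackProbe()
litellm.success_callback = [probe.on_success]   # bound method keeps per-test state isolated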
@ -205,13 +211,27 @@ def test_azure_completion_stream():
|
||||||
assert response_in_success_handler == complete_streaming_response
|
assert response_in_success_handler == complete_streaming_response
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
pytest.fail(f"Error occurred: {e}")
|
pytest.fail(f"Error occurred: {e}")
|
||||||
test_azure_completion_stream()
|
|
||||||
|
|
||||||
def test_async_custom_handler():
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_custom_handler_completion():
|
||||||
try:
|
try:
|
||||||
customHandler2 = MyCustomHandler()
|
customHandler_success = MyCustomHandler()
|
||||||
litellm.callbacks = [customHandler2]
|
customHandler_failure = MyCustomHandler()
|
||||||
litellm.set_verbose = True
|
# success
|
||||||
|
assert customHandler_success.async_success == False
|
||||||
|
litellm.callbacks = [customHandler_success]
|
||||||
|
response = await litellm.acompletion(
|
||||||
|
model="gpt-3.5-turbo",
|
||||||
|
messages=[{
|
||||||
|
"role": "user",
|
||||||
|
"content": "hello from litellm test",
|
||||||
|
}]
|
||||||
|
)
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
assert customHandler_success.async_success == True, "async success is not set to True even after success"
|
||||||
|
assert customHandler_success.async_completion_kwargs.get("model") == "gpt-3.5-turbo"
|
||||||
|
# failure
|
||||||
|
litellm.callbacks = [customHandler_failure]
|
||||||
messages = [
|
messages = [
|
||||||
{"role": "system", "content": "You are a helpful assistant."},
|
{"role": "system", "content": "You are a helpful assistant."},
|
||||||
{
|
{
|
||||||
|
@ -219,77 +239,101 @@ def test_async_custom_handler():
|
||||||
"content": "how do i kill someone",
|
"content": "how do i kill someone",
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
async def test_1():
|
|
||||||
try:
|
|
||||||
response = await litellm.acompletion(
|
|
||||||
model="gpt-3.5-turbo",
|
|
||||||
messages=messages,
|
|
||||||
api_key="test",
|
|
||||||
)
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
|
|
||||||
assert customHandler2.async_failure == False
|
assert customHandler_failure.async_failure == False
|
||||||
asyncio.run(test_1())
|
try:
|
||||||
assert customHandler2.async_failure == True, "async failure is not set to True even after failure"
|
|
||||||
assert customHandler2.async_completion_kwargs_fail.get("model") == "gpt-3.5-turbo"
|
|
||||||
assert len(str(customHandler2.async_completion_kwargs_fail.get("exception"))) > 10 # exppect APIError("OpenAIException - Error code: 401 - {'error': {'message': 'Incorrect API key provided: test. You can find your API key at https://platform.openai.com/account/api-keys.', 'type': 'invalid_request_error', 'param': None, 'code': 'invalid_api_key'}}"), 'traceback_exception': 'Traceback (most recent call last):\n File "/Users/ishaanjaffer/Github/litellm/litellm/llms/openai.py", line 269, in acompletion\n response = await openai_aclient.chat.completions.create(**data)\n File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/openai/resources/chat/completions.py", line 119
|
|
||||||
print("Passed setting async failure")
|
|
||||||
|
|
||||||
async def test_2():
|
|
||||||
response = await litellm.acompletion(
|
response = await litellm.acompletion(
|
||||||
model="gpt-3.5-turbo",
|
model="gpt-3.5-turbo",
|
||||||
messages=[{
|
messages=messages,
|
||||||
"role": "user",
|
api_key="my-bad-key",
|
||||||
"content": "hello from litellm test",
|
)
|
||||||
}]
|
except:
|
||||||
)
|
pass
|
||||||
print("\n response", response)
|
assert customHandler_failure.async_failure == True, "async failure is not set to True even after failure"
|
||||||
assert customHandler2.async_success == False
|
assert customHandler_failure.async_completion_kwargs_fail.get("model") == "gpt-3.5-turbo"
|
||||||
asyncio.run(test_2())
|
assert len(str(customHandler_failure.async_completion_kwargs_fail.get("exception"))) > 10 # expect APIError("OpenAIException - Error code: 401 - {'error': {'message': 'Incorrect API key provided: test. You can find your API key at https://platform.openai.com/account/api-keys.', 'type': 'invalid_request_error', 'param': None, 'code': 'invalid_api_key'}}"), 'traceback_exception': 'Traceback (most recent call last):\n File "/Users/ishaanjaffer/Github/litellm/litellm/llms/openai.py", line 269, in acompletion\n response = await openai_aclient.chat.completions.create(**data)\n File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/openai/resources/chat/completions.py", line 119
|
||||||
assert customHandler2.async_success == True, "async success is not set to True even after success"
|
litellm.callbacks = []
|
||||||
assert customHandler2.async_completion_kwargs.get("model") == "gpt-3.5-turbo"
|
print("Passed setting async failure")
|
||||||
|
except Exception as e:
|
||||||
|
pytest.fail(f"An exception occurred - {str(e)}")
|
||||||
|
# asyncio.run(test_async_custom_handler_completion())
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
async def test_3():
|
async def test_async_custom_handler_embedding():
|
||||||
response = await litellm.aembedding(
|
try:
|
||||||
|
customHandler_embedding = MyCustomHandler()
|
||||||
|
litellm.callbacks = [customHandler_embedding]
|
||||||
|
# success
|
||||||
|
assert customHandler_embedding.async_success_embedding == False
|
||||||
|
response = await litellm.aembedding(
|
||||||
model="text-embedding-ada-002",
|
model="text-embedding-ada-002",
|
||||||
input = ["hello world"],
|
input = ["hello world"],
|
||||||
)
|
)
|
||||||
print("\n response", response)
|
await asyncio.sleep(1)
|
||||||
assert customHandler2.async_success_embedding == False
|
assert customHandler_embedding.async_success_embedding == True, "async_success_embedding is not set to True even after success"
|
||||||
asyncio.run(test_3())
|
assert customHandler_embedding.async_embedding_kwargs.get("model") == "text-embedding-ada-002"
|
||||||
assert customHandler2.async_success_embedding == True, "async_success_embedding is not set to True even after success"
|
assert customHandler_embedding.async_embedding_response["usage"]["prompt_tokens"] ==2
|
||||||
assert customHandler2.async_embedding_kwargs.get("model") == "text-embedding-ada-002"
|
|
||||||
assert customHandler2.async_embedding_response["usage"]["prompt_tokens"] ==2
|
|
||||||
print("Passed setting async success: Embedding")
|
print("Passed setting async success: Embedding")
|
||||||
|
# failure
|
||||||
|
assert customHandler_embedding.async_failure_embedding == False
|
||||||
print("Testing custom failure callback for embedding")
|
try:
|
||||||
|
response = await litellm.aembedding(
|
||||||
async def test_4():
|
model="text-embedding-ada-002",
|
||||||
try:
|
input = ["hello world"],
|
||||||
response = await litellm.aembedding(
|
api_key="my-bad-key",
|
||||||
model="text-embedding-ada-002",
|
)
|
||||||
input = ["hello world"],
|
except:
|
||||||
api_key="test",
|
pass
|
||||||
)
|
assert customHandler_embedding.async_failure_embedding == True, "async failure embedding is not set to True even after failure"
|
||||||
except:
|
assert customHandler_embedding.async_embedding_kwargs_fail.get("model") == "text-embedding-ada-002"
|
||||||
pass
|
assert len(str(customHandler_embedding.async_embedding_kwargs_fail.get("exception"))) > 10 # exppect APIError("OpenAIException - Error code: 401 - {'error': {'message': 'Incorrect API key provided: test. You can find your API key at https://platform.openai.com/account/api-keys.', 'type': 'invalid_request_error', 'param': None, 'code': 'invalid_api_key'}}"), 'traceback_exception': 'Traceback (most recent call last):\n File "/Users/ishaanjaffer/Github/litellm/litellm/llms/openai.py", line 269, in acompletion\n response = await openai_aclient.chat.completions.create(**data)\n File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/openai/resources/chat/completions.py", line 119
|
||||||
|
|
||||||
assert customHandler2.async_failure_embedding == False
|
|
||||||
asyncio.run(test_4())
|
|
||||||
assert customHandler2.async_failure_embedding == True, "async failure embedding is not set to True even after failure"
|
|
||||||
assert customHandler2.async_embedding_kwargs_fail.get("model") == "text-embedding-ada-002"
|
|
||||||
assert len(str(customHandler2.async_embedding_kwargs_fail.get("exception"))) > 10 # exppect APIError("OpenAIException - Error code: 401 - {'error': {'message': 'Incorrect API key provided: test. You can find your API key at https://platform.openai.com/account/api-keys.', 'type': 'invalid_request_error', 'param': None, 'code': 'invalid_api_key'}}"), 'traceback_exception': 'Traceback (most recent call last):\n File "/Users/ishaanjaffer/Github/litellm/litellm/llms/openai.py", line 269, in acompletion\n response = await openai_aclient.chat.completions.create(**data)\n File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/openai/resources/chat/completions.py", line 119
|
|
||||||
print("Passed setting async failure")
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
pytest.fail(f"Error occurred: {e}")
|
pytest.fail(f"An exception occurred - {str(e)}")
|
||||||
# test_async_custom_handler()
|
# asyncio.run(test_async_custom_handler_embedding())
|
||||||
|
|
||||||
|
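# Editor's note: a minimal sketch (not part of this diff) of the kind of handler the assertions
# above assume. The real MyCustomHandler is defined earlier in this test file; the attribute names
# below simply mirror the assertions, and subclassing litellm's CustomLogger is an assumption.
from litellm.integrations.custom_logger import CustomLogger

class SketchEmbeddingHandler(CustomLogger):
    def __init__(self):
        self.async_success_embedding = False
        self.async_failure_embedding = False
        self.async_embedding_kwargs = None
        self.async_embedding_kwargs_fail = None
        self.async_embedding_response = None

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        # the real handler also distinguishes embedding calls from completion calls; elided here
        self.async_success_embedding = True
        self.async_embedding_kwargs = kwargs
        self.async_embedding_response = response_obj

    async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
        self.async_failure_embedding = True
        self.async_embedding_kwargs_fail = kwargs
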
@pytest.mark.asyncio
async def test_async_custom_handler_embedding_optional_param():
    """
    Tests whether the OpenAI optional params for embedding (user + encoding_format) are logged.
    """
    customHandler_optional_params = MyCustomHandler()
    litellm.callbacks = [customHandler_optional_params]
    response = await litellm.aembedding(
        model="azure/azure-embedding-model",
        input=["hello world"],
        user="John"
    )
    await asyncio.sleep(1)  # success callback is async
    assert customHandler_optional_params.user == "John"
    assert customHandler_optional_params.user == customHandler_optional_params.data_sent_to_api["user"]

# asyncio.run(test_async_custom_handler_embedding_optional_param())


@pytest.mark.asyncio
async def test_async_custom_handler_embedding_optional_param_bedrock():
    """
    Tests whether the OpenAI optional params for embedding (user + encoding_format) are logged,
    but makes sure they are not sent to the non-OpenAI/Azure endpoint (which would raise errors).
    """
    litellm.drop_params = True
    litellm.set_verbose = True
    customHandler_optional_params = MyCustomHandler()
    litellm.callbacks = [customHandler_optional_params]
    response = await litellm.aembedding(
        model="bedrock/amazon.titan-embed-text-v1",
        input=["hello world"],
        user="John"
    )
    await asyncio.sleep(1)  # success callback is async
    assert customHandler_optional_params.user == "John"
    assert "user" not in customHandler_optional_params.data_sent_to_api

def test_redis_cache_completion_stream():
    from litellm import Cache
    # Important Test - This tests if we can add to streaming cache, when custom callbacks are set
    import random
    try:
@@ -316,13 +360,10 @@ def test_redis_cache_completion_stream():
        print("\nresponse 2", response_2_content)
        assert response_1_content == response_2_content, f"Response 1 != Response 2. Same params, Response 1{response_1_content} != Response 2{response_2_content}"
        litellm.success_callback = []
        litellm._async_success_callback = []
        litellm.cache = None
    except Exception as e:
        print(e)
        litellm.success_callback = []
        raise e

# test_redis_cache_completion_stream()
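# Editor's note: a minimal sketch (not part of this diff) of the Redis cache wiring that
# test_redis_cache_completion_stream exercises; host/port/password values are placeholders
# read from the environment.
import os
import litellm
from litellm import Cache, completion

litellm.cache = Cache(
    type="redis",
    host=os.environ.get("REDIS_HOST", "localhost"),
    port=os.environ.get("REDIS_PORT", "6379"),
    password=os.environ.get("REDIS_PASSWORD", ""),
)

# the test streams the same request twice and asserts the joined chunks are identical,
# i.e. the second call is served from the streaming cache
response = completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hello"}],
    stream=True,
)
print("".join(chunk.choices[0].delta.content or "" for chunk in response))
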
litellm/tests/test_dynamodb_logs.py (new file, 120 lines)
@@ -0,0 +1,120 @@
import sys
import os
import io, asyncio
# import logging
# logging.basicConfig(level=logging.DEBUG)
sys.path.insert(0, os.path.abspath('../..'))

from litellm import completion
import litellm
litellm.num_retries = 3

import time, random
import pytest


def pre_request():
    file_name = f"dynamo.log"
    log_file = open(file_name, "a+")

    # Clear the contents of the file by truncating it
    log_file.truncate(0)

    # Save the original stdout so that we can restore it later
    original_stdout = sys.stdout
    # Redirect stdout to the file
    sys.stdout = log_file

    return original_stdout, log_file, file_name


import re
def verify_log_file(log_file_path):

    with open(log_file_path, 'r') as log_file:
        log_content = log_file.read()
        print(f"\nVerifying DynamoDB file = {log_file_path}. File content=", log_content)

        # Define the pattern to search for in the log file
        pattern = r"Response from DynamoDB:{.*?}"

        # Find all matches in the log content
        matches = re.findall(pattern, log_content)

        # Print the DynamoDB success log matches
        print("DynamoDB Success Log Matches:")
        for match in matches:
            print(match)

        # Print the total count of lines containing the specified response
        print(f"Total occurrences of specified response: {len(matches)}")

        # Count the occurrences of successful responses (status code 200 or 201)
        success_count = sum(1 for match in matches if "'HTTPStatusCode': 200" in match or "'HTTPStatusCode': 201" in match)

        # Print the count of successful responses
        print(f"Count of successful responses from DynamoDB: {success_count}")
        assert success_count == 3  # Expect 3 success logs from dynamoDB


def test_dynamo_logging():
    # all dynamodb requests need to be in one test function
    # since we are modifying stdout, and pytest runs tests in parallel
    try:
        # pre
        # redirect stdout to log_file

        litellm.success_callback = ["dynamodb"]
        litellm.dynamodb_table_name = "litellm-logs-1"
        litellm.set_verbose = True
        original_stdout, log_file, file_name = pre_request()

        print("Testing async dynamoDB logging")
        async def _test():
            return await litellm.acompletion(
                model="gpt-3.5-turbo",
                messages=[{"role": "user", "content": "This is a test"}],
                max_tokens=100,
                temperature=0.7,
                user="ishaan-2"
            )
        response = asyncio.run(_test())
        print(f"response: {response}")

        # streaming + async
        async def _test2():
            response = await litellm.acompletion(
                model="gpt-3.5-turbo",
                messages=[{"role": "user", "content": "This is a test"}],
                max_tokens=10,
                temperature=0.7,
                user="ishaan-2",
                stream=True
            )
            async for chunk in response:
                pass
        asyncio.run(_test2())

        # aembedding()
        async def _test3():
            return await litellm.aembedding(
                model="text-embedding-ada-002",
                input=["hi"],
                user="ishaan-2"
            )
        response = asyncio.run(_test3())
        time.sleep(1)
    except Exception as e:
        pytest.fail(f"An exception occurred - {e}")
    finally:
        # post, close log file and verify
        # Reset stdout to the original value
        sys.stdout = original_stdout
        # Close the file
        log_file.close()
        verify_log_file(file_name)
        print("Passed! Testing async dynamoDB logging")

# test_dynamo_logging_async()
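# Editor's note: a minimal sketch (not part of this diff) of enabling the DynamoDB logger that
# the test above exercises; the table name is a placeholder and AWS credentials are assumed to
# be configured in the environment.
import litellm

litellm.success_callback = ["dynamodb"]
litellm.dynamodb_table_name = "my-litellm-logs"  # placeholder table name

response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hi"}],
)
# each successful call is written to the DynamoDB table by the success callback
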
@@ -164,7 +164,7 @@ def test_bedrock_embedding_titan():
        assert all(isinstance(x, float) for x in response['data'][0]['embedding']), "Expected response to be a list of floats"
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")
test_bedrock_embedding_titan()


def test_bedrock_embedding_cohere():
    try:
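# Editor's note: the hunk above only shows the tail of test_bedrock_embedding_titan; this is a
# minimal sketch (not part of the diff) of the kind of call it asserts on. The model id is taken
# from other tests in this diff, and AWS credentials are assumed to be configured.
import litellm

response = litellm.embedding(
    model="bedrock/amazon.titan-embed-text-v1",
    input=["good morning from litellm"],
)
print(len(response["data"][0]["embedding"]))  # expect a list of floats
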
@@ -21,6 +21,7 @@ from concurrent.futures import ThreadPoolExecutor
import pytest
litellm.vertex_project = "pathrise-convert-1606954137718"
litellm.vertex_location = "us-central1"
litellm.num_retries = 0

# litellm.failure_callback = ["sentry"]
#### What this tests ####
@@ -38,10 +39,11 @@ models = ["command-nightly"]
# Test 1: Context Window Errors
@pytest.mark.parametrize("model", models)
def test_context_window(model):
    print("Testing context window error")
    sample_text = "Say error 50 times" * 1000000
    messages = [{"content": sample_text, "role": "user"}]
    try:
        litellm.set_verbose = True
        response = completion(model=model, messages=messages)
        print(f"response: {response}")
        print("FAILED!")
@@ -176,7 +178,7 @@ def test_completion_azure_exception():
    try:
        import openai
        print("azure gpt-3.5 test\n\n")
        litellm.set_verbose = True
        ## Test azure call
        old_azure_key = os.environ["AZURE_API_KEY"]
        os.environ["AZURE_API_KEY"] = "good morning"
@@ -189,6 +191,7 @@ def test_completion_azure_exception():
                }
            ],
        )
        os.environ["AZURE_API_KEY"] = old_azure_key
        print(f"response: {response}")
        print(response)
    except openai.AuthenticationError as e:
@@ -196,14 +199,14 @@ def test_completion_azure_exception():
        print("good job got the correct error for azure when key not set")
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")
# test_completion_azure_exception()


async def asynctest_completion_azure_exception():
    try:
        import openai
        import litellm
        print("azure gpt-3.5 test\n\n")
        litellm.set_verbose = True
        ## Test azure call
        old_azure_key = os.environ["AZURE_API_KEY"]
        os.environ["AZURE_API_KEY"] = "good morning"
@@ -226,19 +229,75 @@ async def asynctest_completion_azure_exception():
        print("Got wrong exception")
        print("exception", e)
        pytest.fail(f"Error occurred: {e}")

# import asyncio
# asyncio.run(
#     asynctest_completion_azure_exception()
# )


def asynctest_completion_openai_exception_bad_model():
    try:
        import openai
        import litellm, asyncio
        print("azure exception bad model\n\n")
        litellm.set_verbose = True
        ## Test azure call
        async def test():
            response = await litellm.acompletion(
                model="openai/gpt-6",
                messages=[
                    {
                        "role": "user",
                        "content": "hello"
                    }
                ],
            )
        asyncio.run(test())
    except openai.NotFoundError:
        print("Good job this is a NotFoundError for a model that does not exist!")
        print("Passed")
    except Exception as e:
        print("Raised wrong type of exception", type(e))
        assert isinstance(e, openai.BadRequestError)
        pytest.fail(f"Error occurred: {e}")

# asynctest_completion_openai_exception_bad_model()


def asynctest_completion_azure_exception_bad_model():
    try:
        import openai
        import litellm, asyncio
        print("azure exception bad model\n\n")
        litellm.set_verbose = True
        ## Test azure call
        async def test():
            response = await litellm.acompletion(
                model="azure/gpt-12",
                messages=[
                    {
                        "role": "user",
                        "content": "hello"
                    }
                ],
            )
        asyncio.run(test())
    except openai.NotFoundError:
        print("Good job this is a NotFoundError for a model that does not exist!")
        print("Passed")
    except Exception as e:
        print("Raised wrong type of exception", type(e))
        assert isinstance(e, openai.BadRequestError)
        pytest.fail(f"Error occurred: {e}")

# asynctest_completion_azure_exception_bad_model()


def test_completion_openai_exception():
    # test if openai:gpt raises openai.AuthenticationError
    try:
        import openai
        print("openai gpt-3.5 test\n\n")
        litellm.set_verbose = True
        ## Test azure call
        old_azure_key = os.environ["OPENAI_API_KEY"]
        os.environ["OPENAI_API_KEY"] = "good morning"
@@ -255,11 +314,38 @@ def test_completion_openai_exception():
        print(response)
    except openai.AuthenticationError as e:
        os.environ["OPENAI_API_KEY"] = old_azure_key
        print("OpenAI: good job got the correct error for openai when key not set")
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")
# test_completion_openai_exception()


def test_completion_mistral_exception():
    # test if mistral/mistral-tiny raises openai.AuthenticationError
    try:
        import openai
        print("Testing mistral ai exception mapping")
        litellm.set_verbose = True
        ## Test azure call
        old_azure_key = os.environ["MISTRAL_API_KEY"]
        os.environ["MISTRAL_API_KEY"] = "good morning"
        response = completion(
            model="mistral/mistral-tiny",
            messages=[
                {
                    "role": "user",
                    "content": "hello"
                }
            ],
        )
        print(f"response: {response}")
        print(response)
    except openai.AuthenticationError as e:
        os.environ["MISTRAL_API_KEY"] = old_azure_key
        print("good job got the correct error for openai when key not set")
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")
# test_completion_mistral_exception()

@@ -9,33 +9,107 @@ from litellm import completion
import litellm
litellm.num_retries = 3
litellm.success_callback = ["langfuse"]
os.environ["LANGFUSE_DEBUG"] = "True"
import time
import pytest


def search_logs(log_file_path):
    """
    Searches the given log file for logs containing the "/api/public" string.

    Parameters:
    - log_file_path (str): The path to the log file to be searched.

    Returns:
    - None

    Raises:
    - Exception: If there are any bad logs found in the log file.
    """
    import re
    print("\n searching logs")
    bad_logs = []
    good_logs = []
    all_logs = []
    try:
        with open(log_file_path, 'r') as log_file:
            lines = log_file.readlines()
            print(f"searching log lines: {lines}")
            for line in lines:
                all_logs.append(line.strip())
                if "/api/public" in line:
                    print("Found log with /api/public:")
                    print(line.strip())
                    print("\n\n")
                    match = re.search(r'receive_response_headers.complete return_value=\(b\'HTTP/1.1\', (\d+),', line)
                    if match:
                        status_code = int(match.group(1))
                        if status_code != 200 and status_code != 201:
                            print("got a BAD log")
                            bad_logs.append(line.strip())
                        else:
                            good_logs.append(line.strip())
        print("\nBad Logs")
        print(bad_logs)
        if len(bad_logs) > 0:
            raise Exception(f"bad logs, Bad logs = {bad_logs}")

        print("\nGood Logs")
        print(good_logs)
        if len(good_logs) <= 0:
            raise Exception(f"There were no Good Logs from Langfuse. No logs with /api/public status 200. \nAll logs:{all_logs}")

    except Exception as e:
        raise e


def pre_langfuse_setup():
    """
    Routes litellm debug logs to langfuse.log so they can be inspected after the test run.
    """
    # sends logs to langfuse.log
    import logging
    # Configure the logging to write to a file
    logging.basicConfig(filename="langfuse.log", level=logging.DEBUG)
    logger = logging.getLogger()

    # Add a FileHandler to the logger
    file_handler = logging.FileHandler("langfuse.log", mode='w')
    file_handler.setLevel(logging.DEBUG)
    logger.addHandler(file_handler)
    return

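# Editor's note: a minimal sketch (not part of this diff) of the Langfuse wiring these tests rely
# on. The LANGFUSE_* key values are placeholders; litellm.success_callback = ["langfuse"] routes
# successful calls to Langfuse, and LANGFUSE_DEBUG surfaces the HTTP traffic that search_logs() greps.
import os
import litellm

os.environ["LANGFUSE_PUBLIC_KEY"] = "pk-lf-..."  # placeholder
os.environ["LANGFUSE_SECRET_KEY"] = "sk-lf-..."  # placeholder
os.environ["LANGFUSE_DEBUG"] = "True"
litellm.success_callback = ["langfuse"]

response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hi"}],
)
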
@pytest.mark.skip(reason="beta test - checking langfuse output")
def test_langfuse_logging_async():
    try:
        pre_langfuse_setup()
        litellm.set_verbose = True
        async def _test_langfuse():
            return await litellm.acompletion(
                model="gpt-3.5-turbo",
                messages=[{"role": "user", "content": "This is a test"}],
                max_tokens=100,
                temperature=0.7,
                timeout=5,
            )
        response = asyncio.run(_test_langfuse())
        print(f"response: {response}")

        # time.sleep(2)
        # # check langfuse.log to see if there was a failed response
        # search_logs("langfuse.log")
    except litellm.Timeout as e:
        pass
    except Exception as e:
        pytest.fail(f"An exception occurred - {e}")

test_langfuse_logging_async()


@pytest.mark.skip(reason="beta test - checking langfuse output")
def test_langfuse_logging():
    try:
        pre_langfuse_setup()
        litellm.set_verbose = True
        response = completion(model="claude-instant-1.2",
                              messages=[{
                                  "role": "user",
@@ -43,17 +117,20 @@ def test_langfuse_logging():
                              }],
                              max_tokens=10,
                              temperature=0.2,
                              )
        print(response)
        # time.sleep(5)
        # # check langfuse.log to see if there was a failed response
        # search_logs("langfuse.log")

    except litellm.Timeout as e:
        pass
    except Exception as e:
        pytest.fail(f"An exception occurred - {e}")

test_langfuse_logging()


@pytest.mark.skip(reason="beta test - checking langfuse output")
def test_langfuse_logging_stream():
    try:
        litellm.set_verbose = True
@@ -77,6 +154,7 @@ def test_langfuse_logging_stream():

# test_langfuse_logging_stream()


@pytest.mark.skip(reason="beta test - checking langfuse output")
def test_langfuse_logging_custom_generation_name():
    try:
        litellm.set_verbose = True
@@ -99,8 +177,8 @@ def test_langfuse_logging_custom_generation_name():
        pytest.fail(f"An exception occurred - {e}")
        print(e)

# test_langfuse_logging_custom_generation_name()

@pytest.mark.skip(reason="beta test - checking langfuse output")
def test_langfuse_logging_function_calling():
    function1 = [
        {
@@ -17,10 +17,10 @@ model_alias_map = {
    "good-model": "anyscale/meta-llama/Llama-2-7b-chat-hf"
}


def test_model_alias_map():
    try:
        litellm.model_alias_map = model_alias_map
        response = completion(
            "good-model",
            messages=[{"role": "user", "content": "Hey, how's it going?"}],
@@ -1,100 +1,37 @@
import sys, os
import traceback
from dotenv import load_dotenv

load_dotenv()
import os, io

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
import pytest
import litellm

## for ollama we can't test making the completion call
from litellm.utils import get_optional_params, get_llm_provider


def test_get_ollama_params():
    try:
        converted_params = get_optional_params(custom_llm_provider="ollama", model="llama2", max_tokens=20, temperature=0.5, stream=True)
        print("Converted params", converted_params)
        assert converted_params == {'num_predict': 20, 'stream': True, 'temperature': 0.5}, f"{converted_params} != {'num_predict': 20, 'stream': True, 'temperature': 0.5}"
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")
# test_get_ollama_params()


def test_get_ollama_model():
    try:
        model, custom_llm_provider, _, _ = get_llm_provider("ollama/code-llama-22")
        print("Model", "custom_llm_provider", model, custom_llm_provider)
        assert custom_llm_provider == "ollama", f"{custom_llm_provider} != ollama"
        assert model == "code-llama-22", f"{model} != code-llama-22"
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")
# test_get_ollama_model()
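# Editor's note: a minimal sketch (not part of this diff) of how the converted Ollama params above
# would be used in a real call; like the commented-out tests in this file, it assumes a local
# Ollama server at the default address.
import litellm

response = litellm.completion(
    model="ollama/llama2",
    messages=[{"role": "user", "content": "Why is the sky blue?"}],
    max_tokens=20,      # mapped to Ollama's num_predict by get_optional_params
    temperature=0.5,
    stream=True,
    api_base="http://localhost:11434",
)
for chunk in response:
    print(chunk["choices"][0]["delta"])
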
@ -16,6 +16,19 @@
|
||||||
# user_message = "respond in 20 words. who are you?"
|
# user_message = "respond in 20 words. who are you?"
|
||||||
# messages = [{ "content": user_message,"role": "user"}]
|
# messages = [{ "content": user_message,"role": "user"}]
|
||||||
|
|
||||||
|
# async def test_async_ollama_streaming():
|
||||||
|
# try:
|
||||||
|
# litellm.set_verbose = True
|
||||||
|
# response = await litellm.acompletion(model="ollama/mistral-openorca",
|
||||||
|
# messages=[{"role": "user", "content": "Hey, how's it going?"}],
|
||||||
|
# stream=True)
|
||||||
|
# async for chunk in response:
|
||||||
|
# print(chunk)
|
||||||
|
# except Exception as e:
|
||||||
|
# print(e)
|
||||||
|
|
||||||
|
# asyncio.run(test_async_ollama_streaming())
|
||||||
|
|
||||||
# def test_completion_ollama():
|
# def test_completion_ollama():
|
||||||
# try:
|
# try:
|
||||||
# response = completion(
|
# response = completion(
|
||||||
|
@ -29,7 +42,7 @@
|
||||||
# except Exception as e:
|
# except Exception as e:
|
||||||
# pytest.fail(f"Error occurred: {e}")
|
# pytest.fail(f"Error occurred: {e}")
|
||||||
|
|
||||||
# test_completion_ollama()
|
# # test_completion_ollama()
|
||||||
|
|
||||||
# def test_completion_ollama_with_api_base():
|
# def test_completion_ollama_with_api_base():
|
||||||
# try:
|
# try:
|
||||||
|
@ -42,7 +55,7 @@
|
||||||
# except Exception as e:
|
# except Exception as e:
|
||||||
# pytest.fail(f"Error occurred: {e}")
|
# pytest.fail(f"Error occurred: {e}")
|
||||||
|
|
||||||
# test_completion_ollama_with_api_base()
|
# # test_completion_ollama_with_api_base()
|
||||||
|
|
||||||
|
|
||||||
# def test_completion_ollama_custom_prompt_template():
|
# def test_completion_ollama_custom_prompt_template():
|
||||||
|
@ -72,7 +85,7 @@
|
||||||
# traceback.print_exc()
|
# traceback.print_exc()
|
||||||
# pytest.fail(f"Error occurred: {e}")
|
# pytest.fail(f"Error occurred: {e}")
|
||||||
|
|
||||||
# test_completion_ollama_custom_prompt_template()
|
# # test_completion_ollama_custom_prompt_template()
|
||||||
|
|
||||||
# async def test_completion_ollama_async_stream():
|
# async def test_completion_ollama_async_stream():
|
||||||
# user_message = "what is the weather"
|
# user_message = "what is the weather"
|
||||||
|
@ -98,8 +111,8 @@
|
||||||
# except Exception as e:
|
# except Exception as e:
|
||||||
# pytest.fail(f"Error occurred: {e}")
|
# pytest.fail(f"Error occurred: {e}")
|
||||||
|
|
||||||
# import asyncio
|
# # import asyncio
|
||||||
# asyncio.run(test_completion_ollama_async_stream())
|
# # asyncio.run(test_completion_ollama_async_stream())
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@ -154,8 +167,35 @@
|
||||||
# pass
|
# pass
|
||||||
# pytest.fail(f"Error occurred: {e}")
|
# pytest.fail(f"Error occurred: {e}")
|
||||||
|
|
||||||
# test_completion_expect_error()
|
# # test_completion_expect_error()
|
||||||
|
|
||||||
# if __name__ == "__main__":
|
|
||||||
# import asyncio
|
# def test_ollama_llava():
|
||||||
# asyncio.run(main())
|
# litellm.set_verbose=True
|
||||||
|
# # same params as gpt-4 vision
|
||||||
|
# response = completion(
|
||||||
|
# model = "ollama/llava",
|
||||||
|
# messages=[
|
||||||
|
# {
|
||||||
|
# "role": "user",
|
||||||
|
# "content": [
|
||||||
|
# {
|
||||||
|
# "type": "text",
|
||||||
|
# "text": "What is in this picture"
|
||||||
|
# },
|
||||||
|
# {
|
||||||
|
# "type": "image_url",
|
||||||
|
# "image_url": {
|
||||||
|
# "url": "iVBORw0KGgoAAAANSUhEUgAAAG0AAABmCAYAAADBPx+VAAAACXBIWXMAAAsTAAALEwEAmpwYAAAAAXNSR0IArs4c6QAAAARnQU1BAACxjwv8YQUAAA3VSURBVHgB7Z27r0zdG8fX743i1bi1ikMoFMQloXRpKFFIqI7LH4BEQ+NWIkjQuSWCRIEoULk0gsK1kCBI0IhrQVT7tz/7zZo888yz1r7MnDl7z5xvsjkzs2fP3uu71nNfa7lkAsm7d++Sffv2JbNmzUqcc8m0adOSzZs3Z+/XES4ZckAWJEGWPiCxjsQNLWmQsWjRIpMseaxcuTKpG/7HP27I8P79e7dq1ars/yL4/v27S0ejqwv+cUOGEGGpKHR37tzJCEpHV9tnT58+dXXCJDdECBE2Ojrqjh071hpNECjx4cMHVycM1Uhbv359B2F79+51586daxN/+pyRkRFXKyRDAqxEp4yMlDDzXG1NPnnyJKkThoK0VFd1ELZu3TrzXKxKfW7dMBQ6bcuWLW2v0VlHjx41z717927ba22U9APcw7Nnz1oGEPeL3m3p2mTAYYnFmMOMXybPPXv2bNIPpFZr1NHn4HMw0KRBjg9NuRw95s8PEcz/6DZELQd/09C9QGq5RsmSRybqkwHGjh07OsJSsYYm3ijPpyHzoiacg35MLdDSIS/O1yM778jOTwYUkKNHWUzUWaOsylE00MyI0fcnOwIdjvtNdW/HZwNLGg+sR1kMepSNJXmIwxBZiG8tDTpEZzKg0GItNsosY8USkxDhD0Rinuiko2gfL/RbiD2LZAjU9zKQJj8RDR0vJBR1/Phx9+PHj9Z7REF4nTZkxzX4LCXHrV271qXkBAPGfP/atWvu/PnzHe4C97F48eIsRLZ9+3a3f/9+87dwP1JxaF7/3r17ba+5l4EcaVo0lj3SBq5kGTJSQmLWMjgYNei2GPT1MuMqGTDEFHzeQSP2wi/jGnkmPJ/nhccs44jvDAxpVcxnq0F6eT8h4ni/iIWpR5lPyA6ETkNXoSukvpJAD3AsXLiwpZs49+fPn5ke4j10TqYvegSfn0OnafC+Tv9ooA/JPkgQysqQNBzagXY55nO/oa1F7qvIPWkRL12WRpMWUvpVDYmxAPehxWSe8ZEXL20sadYIozfmNch4QJPAfeJgW3rNsnzphBKNJM2KKODo1rVOMRYik5ETy3ix4qWNI81qAAirizgMIc+yhTytx0JWZuNI03qsrgWlGtwjoS9XwgUhWGyhUaRZZQNNIEwCiXD16tXcAHUs79co0vSD8rrJCIW98pzvxpAWyyo3HYwqS0+H0BjStClcZJT5coMm6D2LOF8TolGJtK9fvyZpyiC5ePFi9nc/oJU4eiEP0jVoAnHa9wyJycITMP78+eMeP37sXrx44d6+fdt6f82aNdkx1pg9e3Zb5W+RSRE+n+VjksQWifvVaTKFhn5O8my63K8Qabdv33b379/PiAP//vuvW7BggZszZ072/+TJk91YgkafPn166zXB1rQHFvouAWHq9z3SEevSUerqCn2/dDCeta2jxYbr69evk4MHDyY7d+7MjhMnTiTPnz9Pfv/+nfQT2ggpO2dMF8cghuoM7Ygj5iWCqRlGFml0QC/ftGmTmzt3rmsaKDsgBSPh0/8yPeLLBihLkOKJc0jp8H8vUzcxIA1k6QJ/c78tWEyj5P3o4u9+jywNPdJi5rAH9x0KHcl4Hg570eQp3+vHXGyrmEeigzQsQsjavXt38ujRo44LQuDDhw+TW7duRS1HGgMxhNXHgflaNTOsHyKvHK5Ijo2jbFjJBQK9YwFd6RVMzfgRBmEfP37suBBm/p49e1qjEP2mwTViNRo0VJWH1deMXcNK08uUjVUu7s/zRaL+oLNxz1bpANco4npUgX4G2eFbpDFyQoQxojBCpEGSytmOH8qrH5Q9vuzD6ofQylkCUmh8DBAr+q8JCyVNtWQIidKQE9wNtLSQnS4jDSsxNHogzFuQBw4cyM61UKVsjfr3ooBkPSqqQHesUPWVtzi9/vQi1T+rJj7WiTz4Pt/l3LxUkr5P2VYZaZ4URpsE+st/dujQoaBBYokbrz/8TJNQYLSonrPS9kUaSkPeZyj1AWSj+d+VBoy1pIWVNed8P0Ll/ee5HdGRhrHhR5GGN0r4LGZBaj8oFDJitBTJzIZgFcmU0Y8ytWMZMzJOaXUSrUs5RxKnrxmbb5YXO9VGUhtpXldhEUogFr3IzIsvlpmdosVcGVGXFWp2oU9kLFL3dEkSz6NHEY1sjSRdIuDFWEhd8KxFqsRi1uM/nz9/zpxnwlESONdg6dKlbsaMGS4EHFHtjFIDHwKOo46l4TxSuxgDzi+rE2jg+BaFruOX4HXa0Nnf1lwAPufZeF8/r6zD97WK2qFnGjBxTw5qNGPxT+5T/r7/7RawFC3j4vTp09koCxkeHjqbHJqArmH5UrFKKksnxrK7FuRIs8STfBZv+luugXZ2pR/pP9Ois4z+TiMzUUkUjD0iEi1fzX8GmXyuxUBRcaUfykV0YZnlJGKQpOiGB76x5GeWkWWJc3mOrK6S7xdND+W5N6XyaRgtWJFe13GkaZnKOsYqGdOVVVbGupsyA/l7emTLHi7vwTdirNEt0qxnzAvBFcnQF16xh/TMpUuXHDowhlA9vQVraQhkudRdzOnK+04ZSP3DUhVSP61YsaLtd/ks7ZgtPcXqPqEafHkdqa84X6aCeL7YWlv6edGFHb+ZFICPlljHhg0bKuk0CSvVznWsotRu433alNdFrqG45ejoaPCaUkWERpLXjzFL2Rpllp7PJU2a/v7Ab8N05/9t27Z16KUqoFGsxnI9EosS2niSYg9SpU6B4JgTrvVW1flt1sT+0ADIJU2maXzcUTraGCRaL1Wp9rUMk16PMom8QhruxzvZIegJjFU7LLCePfS8uaQdPny4jTTL0dbee5mYokQsXTIWNY46kuMbnt8Kmec+LGWtOVIl9cT1rCB0V8WqkjAsRwta93TbwNYoGKsUSChN44lgBNCoHLHzquYKrU6qZ8lolCIN0Rh6cP0Q3U6I6IXILYOQI513hJaSKAorFpuHXJNfVlpRtmYBk1Su1obZr5dnKAO+L10Hrj3WZW+E3qh6IszE37F6EB+68mGpvKm4eb9bFrlzrok7fvr0Kfv727dvWRmdVTJHw0qiiCUSZ6wCK+7XL/AcsgNyL74DQQ730sv78Su7+t/A36MdY0sW5o40ahslXr58aZ5HtZB8GH64m9EmMZ7FpYw4T6QnrZfgenrhFxaSiSGXtPnz57e9TkNZLvTjeqhr734CNtrK41L40sUQckmj1lGKQ0rC37x544r8eNXRpnVE3ZZY7zXo8NomiO0ZUCj2uHz58rbXoZ6gc0uA+F6ZeKS/jhRDUq8MKrTho9fEkihMmhxtBI1DxKFY9XLpVcSkfoi8JGnToZO5sU5aiDQIW716ddt7ZLYtMQlhECdBGXZZMWldY5BHm5xgAroWj4C0hbYkSc/jBmggIrXJWlZM6pSETsEPGqZOndr2uuuR5rF16
9a2HoHPdurUKZM4CO1WTPqaDaAd+GFGKdIQkxAn9RuEWcTRyN2KSUgiSgF5aWzPTeA/lN5rZubMmR2bE4SIC4nJoltgAV/dVefZm72AtctUCJU2CMJ327hxY9t7EHbkyJFseq+EJSY16RPo3Dkq1kkr7+q0bNmyDuLQcZBEPYmHVdOBiJyIlrRDq41YPWfXOxUysi5fvtyaj+2BpcnsUV/oSoEMOk2CQGlr4ckhBwaetBhjCwH0ZHtJROPJkyc7UjcYLDjmrH7ADTEBXFfOYmB0k9oYBOjJ8b4aOYSe7QkKcYhFlq3QYLQhSidNmtS2RATwy8YOM3EQJsUjKiaWZ+vZToUQgzhkHXudb/PW5YMHD9yZM2faPsMwoc7RciYJXbGuBqJ1UIGKKLv915jsvgtJxCZDubdXr165mzdvtr1Hz5LONA8jrUwKPqsmVesKa49S3Q4WxmRPUEYdTjgiUcfUwLx589ySJUva3oMkP6IYddq6HMS4o55xBJBUeRjzfa4Zdeg56QZ43LhxoyPo7Lf1kNt7oO8wWAbNwaYjIv5lhyS7kRf96dvm5Jah8vfvX3flyhX35cuX6HfzFHOToS1H4BenCaHvO8pr8iDuwoUL7tevX+b5ZdbBair0xkFIlFDlW4ZknEClsp/TzXyAKVOmmHWFVSbDNw1l1+4f90U6IY/q4V27dpnE9bJ+v87QEydjqx/UamVVPRG+mwkNTYN+9tjkwzEx+atCm/X9WvWtDtAb68Wy9LXa1UmvCDDIpPkyOQ5ZwSzJ4jMrvFcr0rSjOUh+GcT4LSg5ugkW1Io0/SCDQBojh0hPlaJdah+tkVYrnTZowP8iq1F1TgMBBauufyB33x1v+NWFYmT5KmppgHC+NkAgbmRkpD3yn9QIseXymoTQFGQmIOKTxiZIWpvAatenVqRVXf2nTrAWMsPnKrMZHz6bJq5jvce6QK8J1cQNgKxlJapMPdZSR64/UivS9NztpkVEdKcrs5alhhWP9NeqlfWopzhZScI6QxseegZRGeg5a8C3Re1Mfl1ScP36ddcUaMuv24iOJtz7sbUjTS4qBvKmstYJoUauiuD3k5qhyr7QdUHMeCgLa1Ear9NquemdXgmum4fvJ6w1lqsuDhNrg1qSpleJK7K3TF0Q2jSd94uSZ60kK1e3qyVpQK6PVWXp2/FC3mp6jBhKKOiY2h3gtUV64TWM6wDETRPLDfSakXmH3w8g9Jlug8ZtTt4kVF0kLUYYmCCtD/DrQ5YhMGbA9L3ucdjh0y8kOHW5gU/VEEmJTcL4Pz/f7mgoAbYkAAAAAElFTkSuQmCC"
|
||||||
|
# }
|
||||||
|
# }
|
||||||
|
# ]
|
||||||
|
# }
|
||||||
|
# ],
|
||||||
|
# )
|
||||||
|
# print("Response from ollama/llava")
|
||||||
|
# print(response)
|
||||||
|
# test_ollama_llava()
|
||||||
|
|
||||||
|
|
||||||
|
# # PROCESSED CHUNK PRE CHUNK CREATOR
|
||||||
|
|
litellm/tests/test_optional_params.py (new file, 27 lines)
@@ -0,0 +1,27 @@
#### What this tests ####
# This tests if get_optional_params works as expected
import sys, os, time, inspect, asyncio, traceback
import pytest
sys.path.insert(0, os.path.abspath('../..'))
import litellm
from litellm.utils import get_optional_params_embeddings
## get_optional_params_embeddings
### Models: OpenAI, Azure, Bedrock
### Scenarios: w/ optional params + litellm.drop_params = True

def test_bedrock_optional_params_embeddings():
    litellm.drop_params = True
    optional_params = get_optional_params_embeddings(user="John", encoding_format=None, custom_llm_provider="bedrock")
    assert len(optional_params) == 0

def test_openai_optional_params_embeddings():
    litellm.drop_params = True
    optional_params = get_optional_params_embeddings(user="John", encoding_format=None, custom_llm_provider="openai")
    assert len(optional_params) == 1
    assert optional_params["user"] == "John"

def test_azure_optional_params_embeddings():
    litellm.drop_params = True
    optional_params = get_optional_params_embeddings(user="John", encoding_format=None, custom_llm_provider="azure")
    assert len(optional_params) == 1
    assert optional_params["user"] == "John"
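# Editor's note: a minimal sketch (not part of this diff) of the behaviour the tests above pin
# down at the call level: with litellm.drop_params = True, OpenAI-only params such as `user` are
# forwarded to OpenAI/Azure but dropped for providers like Bedrock instead of raising an error.
# Provider credentials are assumed to be configured in the environment.
import litellm

litellm.drop_params = True

# `user` is kept for OpenAI...
litellm.embedding(model="text-embedding-ada-002", input=["hello world"], user="John")

# ...but dropped before the request reaches Bedrock
litellm.embedding(model="bedrock/amazon.titan-embed-text-v1", input=["hello world"], user="John")
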
@@ -19,21 +19,23 @@ from litellm import RateLimitError
from fastapi.testclient import TestClient
from fastapi import FastAPI
from litellm.proxy.proxy_server import router, save_worker_config, initialize  # Replace with the actual module where your FastAPI router is defined

# Here you create a fixture that will be used by your tests
# Make sure the fixture returns TestClient(app)
@pytest.fixture(scope="function")
def client():
    from litellm.proxy.proxy_server import cleanup_router_config_variables
    cleanup_router_config_variables()
    filepath = os.path.dirname(os.path.abspath(__file__))
    config_fp = f"{filepath}/test_configs/test_config_custom_auth.yaml"
    # initialize can run in parallel; it sets variables specific to the FastAPI app, so without a per-test reset different tests pick up the wrong config
    app = FastAPI()
    initialize(config=config_fp)

    app.include_router(router)  # Include your router in the test app
    return TestClient(app)


def test_custom_auth(client):
    try:
@ -3,7 +3,7 @@ import traceback
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
import os, io
|
import os, io, asyncio
|
||||||
|
|
||||||
# this file is to test litellm/proxy
|
# this file is to test litellm/proxy
|
||||||
|
|
||||||
|
@ -21,21 +21,24 @@ from fastapi.testclient import TestClient
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from litellm.proxy.proxy_server import router, save_worker_config, initialize # Replace with the actual module where your FastAPI router is defined
|
from litellm.proxy.proxy_server import router, save_worker_config, initialize # Replace with the actual module where your FastAPI router is defined
|
||||||
filepath = os.path.dirname(os.path.abspath(__file__))
|
filepath = os.path.dirname(os.path.abspath(__file__))
|
||||||
config_fp = f"{filepath}/test_configs/test_custom_logger.yaml"
|
|
||||||
python_file_path = f"{filepath}/test_configs/custom_callbacks.py"
|
python_file_path = f"{filepath}/test_configs/custom_callbacks.py"
|
||||||
save_worker_config(config=config_fp, model=None, alias=None, api_base=None, api_version=None, debug=False, temperature=None, max_tokens=None, request_timeout=600, max_budget=None, telemetry=False, drop_params=True, add_function_to_prompt=False, headers=None, save=False, use_queue=False)
|
|
||||||
app = FastAPI()
|
|
||||||
app.include_router(router) # Include your router in the test app
|
|
||||||
@app.on_event("startup")
|
|
||||||
async def wrapper_startup_event():
|
|
||||||
initialize(config=config_fp, model=None, alias=None, api_base=None, api_version=None, debug=True, temperature=None, max_tokens=None, request_timeout=600, max_budget=None, telemetry=False, drop_params=True, add_function_to_prompt=False, headers=None, save=False, use_queue=False)
|
|
||||||
|
|
||||||
# Here you create a fixture that will be used by your tests
|
# @app.on_event("startup")
|
||||||
# Make sure the fixture returns TestClient(app)
|
# async def wrapper_startup_event():
|
||||||
@pytest.fixture(autouse=True)
|
# initialize(config=config_fp)
|
||||||
|
|
||||||
|
# Use the app fixture in your client fixture
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
def client():
|
def client():
|
||||||
with TestClient(app) as client:
|
filepath = os.path.dirname(os.path.abspath(__file__))
|
||||||
yield client
|
config_fp = f"{filepath}/test_configs/test_custom_logger.yaml"
|
||||||
|
initialize(config=config_fp)
|
||||||
|
app = FastAPI()
|
||||||
|
app.include_router(router) # Include your router in the test app
|
||||||
|
return TestClient(app)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# Your bearer token
|
# Your bearer token
|
||||||
token = os.getenv("PROXY_MASTER_KEY")
|
token = os.getenv("PROXY_MASTER_KEY")
|
||||||
|
@ -45,15 +48,76 @@ headers = {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def test_chat_completion(client):
|
print("Testing proxy custom logger")
|
||||||
|
|
||||||
|
def test_embedding(client):
|
||||||
try:
|
try:
|
||||||
# Your test data
|
litellm.set_verbose=False
|
||||||
|
from litellm.proxy.utils import get_instance_fn
|
||||||
|
my_custom_logger = get_instance_fn(
|
||||||
|
value = "custom_callbacks.my_custom_logger",
|
||||||
|
config_file_path=python_file_path
|
||||||
|
)
|
||||||
|
print("id of initialized custom logger", id(my_custom_logger))
|
||||||
|
litellm.callbacks = [my_custom_logger]
|
||||||
|
# Your test data
|
||||||
print("initialized proxy")
|
print("initialized proxy")
|
||||||
# import the initialized custom logger
|
# import the initialized custom logger
|
||||||
print(litellm.callbacks)
|
print(litellm.callbacks)
|
||||||
|
|
||||||
assert len(litellm.callbacks) == 1 # assert litellm is initialized with 1 callback
|
# assert len(litellm.callbacks) == 1 # assert litellm is initialized with 1 callback
|
||||||
my_custom_logger = litellm.callbacks[0]
|
print("my_custom_logger", my_custom_logger)
|
||||||
|
assert my_custom_logger.async_success_embedding == False
|
||||||
|
|
||||||
|
test_data = {
|
||||||
|
"model": "azure-embedding-model",
|
||||||
|
"input": ["hello"]
|
||||||
|
}
|
||||||
|
response = client.post("/embeddings", json=test_data, headers=headers)
|
||||||
|
print("made request", response.status_code, response.text)
|
||||||
|
print("vars my custom logger /embeddings", vars(my_custom_logger), "id", id(my_custom_logger))
|
||||||
|
assert my_custom_logger.async_success_embedding == True # checks if the status of async_success is True, only the async_log_success_event can set this to true
|
||||||
|
assert my_custom_logger.async_embedding_kwargs["model"] == "azure-embedding-model" # checks if kwargs passed to async_log_success_event are correct
|
||||||
|
kwargs = my_custom_logger.async_embedding_kwargs
|
||||||
|
litellm_params = kwargs.get("litellm_params")
|
||||||
|
metadata = litellm_params.get("metadata", None)
|
||||||
|
print("\n\n Metadata in custom logger kwargs", litellm_params.get("metadata"))
|
||||||
|
assert metadata is not None
|
||||||
|
assert "user_api_key" in metadata
|
||||||
|
assert "headers" in metadata
|
||||||
|
proxy_server_request = litellm_params.get("proxy_server_request")
|
||||||
|
model_info = litellm_params.get("model_info")
|
||||||
|
assert proxy_server_request == {'url': 'http://testserver/embeddings', 'method': 'POST', 'headers': {'host': 'testserver', 'accept': '*/*', 'accept-encoding': 'gzip, deflate', 'connection': 'keep-alive', 'user-agent': 'testclient', 'authorization': 'Bearer sk-1234', 'content-length': '54', 'content-type': 'application/json'}, 'body': {'model': 'azure-embedding-model', 'input': ['hello']}}
|
||||||
|
assert model_info == {'input_cost_per_token': 0.002, 'mode': 'embedding', 'id': 'hello'}
|
||||||
|
result = response.json()
|
||||||
|
print(f"Received response: {result}")
|
||||||
|
print("Passed Embedding custom logger on proxy!")
|
||||||
|
except Exception as e:
|
||||||
|
pytest.fail(f"LiteLLM Proxy test failed. Exception {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
|
def test_chat_completion(client):
|
||||||
|
try:
|
||||||
|
# Your test data
|
||||||
|
|
||||||
|
print("initialized proxy")
|
||||||
|
litellm.set_verbose=False
|
||||||
|
from litellm.proxy.utils import get_instance_fn
|
||||||
|
my_custom_logger = get_instance_fn(
|
||||||
|
value = "custom_callbacks.my_custom_logger",
|
||||||
|
config_file_path=python_file_path
|
||||||
|
)
|
||||||
|
|
||||||
|
print("id of initialized custom logger", id(my_custom_logger))
|
||||||
|
|
||||||
|
litellm.callbacks = [my_custom_logger]
|
||||||
|
# import the initialized custom logger
|
||||||
|
print(litellm.callbacks)
|
||||||
|
|
||||||
|
# assert len(litellm.callbacks) == 1 # assert litellm is initialized with 1 callback
|
||||||
|
|
||||||
|
print("LiteLLM Callbacks", litellm.callbacks)
|
||||||
|
print("my_custom_logger", my_custom_logger)
|
||||||
assert my_custom_logger.async_success == False
|
assert my_custom_logger.async_success == False
|
||||||
|
|
||||||
test_data = {
|
test_data = {
|
||||||
|
@ -61,7 +125,7 @@ def test_chat_completion(client):
|
||||||
"messages": [
|
"messages": [
|
||||||
{
|
{
|
||||||
"role": "user",
|
"role": "user",
|
||||||
"content": "hi"
|
"content": "write a litellm poem"
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
"max_tokens": 10,
|
"max_tokens": 10,
|
||||||
|
@ -70,33 +134,53 @@ def test_chat_completion(client):
|
||||||
|
|
||||||
response = client.post("/chat/completions", json=test_data, headers=headers)
|
response = client.post("/chat/completions", json=test_data, headers=headers)
|
||||||
print("made request", response.status_code, response.text)
|
print("made request", response.status_code, response.text)
|
||||||
|
print("LiteLLM Callbacks", litellm.callbacks)
|
||||||
|
asyncio.sleep(1) # sleep while waiting for callback to run
|
||||||
|
|
||||||
|
print("my_custom_logger in /chat/completions", my_custom_logger, "id", id(my_custom_logger))
|
||||||
|
print("vars my custom logger, ", vars(my_custom_logger))
|
||||||
assert my_custom_logger.async_success == True # checks if the status of async_success is True, only the async_log_success_event can set this to true
|
assert my_custom_logger.async_success == True # checks if the status of async_success is True, only the async_log_success_event can set this to true
|
||||||
assert my_custom_logger.async_completion_kwargs["model"] == "chatgpt-v-2" # checks if kwargs passed to async_log_success_event are correct
|
assert my_custom_logger.async_completion_kwargs["model"] == "chatgpt-v-2" # checks if kwargs passed to async_log_success_event are correct
|
||||||
print("\n\n Custom Logger Async Completion args", my_custom_logger.async_completion_kwargs)
|
print("\n\n Custom Logger Async Completion args", my_custom_logger.async_completion_kwargs)
|
||||||
|
|
||||||
litellm_params = my_custom_logger.async_completion_kwargs.get("litellm_params")
|
litellm_params = my_custom_logger.async_completion_kwargs.get("litellm_params")
|
||||||
|
metadata = litellm_params.get("metadata", None)
|
||||||
|
print("\n\n Metadata in custom logger kwargs", litellm_params.get("metadata"))
|
||||||
|
assert metadata is not None
|
||||||
|
assert "user_api_key" in metadata
|
||||||
|
assert "headers" in metadata
|
||||||
config_model_info = litellm_params.get("model_info")
|
config_model_info = litellm_params.get("model_info")
|
||||||
proxy_server_request_object = litellm_params.get("proxy_server_request")
|
proxy_server_request_object = litellm_params.get("proxy_server_request")
|
||||||
|
|
||||||
assert config_model_info == {'id': 'gm', 'input_cost_per_token': 0.0002, 'mode': 'chat'}
|
assert config_model_info == {'id': 'gm', 'input_cost_per_token': 0.0002, 'mode': 'chat'}
|
||||||
assert proxy_server_request_object == {'url': 'http://testserver/chat/completions', 'method': 'POST', 'headers': {'host': 'testserver', 'accept': '*/*', 'accept-encoding': 'gzip, deflate', 'connection': 'keep-alive', 'user-agent': 'testclient', 'authorization': 'Bearer sk-1234', 'content-length': '105', 'content-type': 'application/json'}, 'body': {'model': 'Azure OpenAI GPT-4 Canada', 'messages': [{'role': 'user', 'content': 'hi'}], 'max_tokens': 10}}
|
assert proxy_server_request_object == {'url': 'http://testserver/chat/completions', 'method': 'POST', 'headers': {'host': 'testserver', 'accept': '*/*', 'accept-encoding': 'gzip, deflate', 'connection': 'keep-alive', 'user-agent': 'testclient', 'authorization': 'Bearer sk-1234', 'content-length': '123', 'content-type': 'application/json'}, 'body': {'model': 'Azure OpenAI GPT-4 Canada', 'messages': [{'role': 'user', 'content': 'write a litellm poem'}], 'max_tokens': 10}}
|
||||||
result = response.json()
|
result = response.json()
|
||||||
print(f"Received response: {result}")
|
print(f"Received response: {result}")
|
||||||
print("\nPassed /chat/completions with Custom Logger!")
|
print("\nPassed /chat/completions with Custom Logger!")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
pytest.fail("LiteLLM Proxy test failed. Exception", e)
|
pytest.fail(f"LiteLLM Proxy test failed. Exception {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
def test_chat_completion_stream(client):
|
def test_chat_completion_stream(client):
|
||||||
try:
|
try:
|
||||||
# Your test data
|
# Your test data
|
||||||
|
litellm.set_verbose=False
|
||||||
|
from litellm.proxy.utils import get_instance_fn
|
||||||
|
my_custom_logger = get_instance_fn(
|
||||||
|
value = "custom_callbacks.my_custom_logger",
|
||||||
|
config_file_path=python_file_path
|
||||||
|
)
|
||||||
|
|
||||||
|
print("id of initialized custom logger", id(my_custom_logger))
|
||||||
|
|
||||||
|
litellm.callbacks = [my_custom_logger]
|
||||||
import json
|
import json
|
||||||
print("initialized proxy")
|
print("initialized proxy")
|
||||||
# import the initialized custom logger
|
# import the initialized custom logger
|
||||||
print(litellm.callbacks)
|
print(litellm.callbacks)
|
||||||
|
|
||||||
assert len(litellm.callbacks) == 1 # assert litellm is initialized with 1 callback
|
|
||||||
my_custom_logger = litellm.callbacks[0]
|
print("LiteLLM Callbacks", litellm.callbacks)
|
||||||
|
print("my_custom_logger", my_custom_logger)
|
||||||
|
|
||||||
assert my_custom_logger.streaming_response_obj == None # no streaming response obj is set pre call
|
assert my_custom_logger.streaming_response_obj == None # no streaming response obj is set pre call
|
||||||
|
|
||||||
|
@@ -148,37 +232,5 @@ def test_chat_completion_stream(client):
          assert complete_response == streamed_response["choices"][0]["message"]["content"]

      except Exception as e:
-         pytest.fail("LiteLLM Proxy test failed. Exception", e)
+         pytest.fail(f"LiteLLM Proxy test failed. Exception {str(e)}")

- def test_embedding(client):
-     try:
-         # Your test data
-         print("initialized proxy")
-         # import the initialized custom logger
-         print(litellm.callbacks)
-
-         assert len(litellm.callbacks) == 1 # assert litellm is initialized with 1 callback
-         my_custom_logger = litellm.callbacks[0]
-         assert my_custom_logger.async_success_embedding == False
-
-         test_data = {
-             "model": "azure-embedding-model",
-             "input": ["hello"]
-         }
-         response = client.post("/embeddings", json=test_data, headers=headers)
-         print("made request", response.status_code, response.text)
-         assert my_custom_logger.async_success_embedding == True # checks if the status of async_success is True, only the async_log_success_event can set this to true
-         assert my_custom_logger.async_embedding_kwargs["model"] == "azure-embedding-model" # checks if kwargs passed to async_log_success_event are correct
-
-         kwargs = my_custom_logger.async_embedding_kwargs
-         litellm_params = kwargs.get("litellm_params")
-         proxy_server_request = litellm_params.get("proxy_server_request")
-         model_info = litellm_params.get("model_info")
-         assert proxy_server_request == {'url': 'http://testserver/embeddings', 'method': 'POST', 'headers': {'host': 'testserver', 'accept': '*/*', 'accept-encoding': 'gzip, deflate', 'connection': 'keep-alive', 'user-agent': 'testclient', 'authorization': 'Bearer sk-1234', 'content-length': '54', 'content-type': 'application/json'}, 'body': {'model': 'azure-embedding-model', 'input': ['hello']}}
-         assert model_info == {'input_cost_per_token': 0.002, 'mode': 'embedding', 'id': 'hello'}
-         result = response.json()
-         print(f"Received response: {result}")
-     except Exception as e:
-         pytest.fail("LiteLLM Proxy test failed. Exception", e)
177  litellm/tests/test_proxy_exception_mapping.py  (new file)
@@ -0,0 +1,177 @@
# test that the proxy actually does exception mapping to the OpenAI format

import sys, os
from dotenv import load_dotenv

load_dotenv()
import os, io, asyncio
sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
import pytest
import litellm, openai
from fastapi.testclient import TestClient
from fastapi import FastAPI
from litellm.proxy.proxy_server import router, save_worker_config, initialize  # Replace with the actual module where your FastAPI router is defined

@pytest.fixture
def client():
    filepath = os.path.dirname(os.path.abspath(__file__))
    config_fp = f"{filepath}/test_configs/test_bad_config.yaml"
    initialize(config=config_fp)
    app = FastAPI()
    app.include_router(router)  # Include your router in the test app
    return TestClient(app)

# raise openai.AuthenticationError
def test_chat_completion_exception(client):
    try:
        # Your test data
        test_data = {
            "model": "gpt-3.5-turbo",
            "messages": [
                {
                    "role": "user",
                    "content": "hi"
                },
            ],
            "max_tokens": 10,
        }

        response = client.post("/chat/completions", json=test_data)

        # make an openai client to call _make_status_error_from_response
        openai_client = openai.OpenAI(api_key="anything")
        openai_exception = openai_client._make_status_error_from_response(response=response)
        assert isinstance(openai_exception, openai.AuthenticationError)

    except Exception as e:
        pytest.fail(f"LiteLLM Proxy test failed. Exception {str(e)}")

# raise openai.AuthenticationError
def test_chat_completion_exception_azure(client):
    try:
        # Your test data
        test_data = {
            "model": "azure-gpt-3.5-turbo",
            "messages": [
                {
                    "role": "user",
                    "content": "hi"
                },
            ],
            "max_tokens": 10,
        }

        response = client.post("/chat/completions", json=test_data)

        # make an openai client to call _make_status_error_from_response
        openai_client = openai.OpenAI(api_key="anything")
        openai_exception = openai_client._make_status_error_from_response(response=response)
        assert isinstance(openai_exception, openai.AuthenticationError)

    except Exception as e:
        pytest.fail(f"LiteLLM Proxy test failed. Exception {str(e)}")

# raise openai.AuthenticationError
def test_embedding_auth_exception_azure(client):
    try:
        # Your test data
        test_data = {
            "model": "azure-embedding",
            "input": ["hi"]
        }

        response = client.post("/embeddings", json=test_data)
        print("Response from proxy=", response)

        # make an openai client to call _make_status_error_from_response
        openai_client = openai.OpenAI(api_key="anything")
        openai_exception = openai_client._make_status_error_from_response(response=response)
        print("Exception raised=", openai_exception)
        assert isinstance(openai_exception, openai.AuthenticationError)

    except Exception as e:
        pytest.fail(f"LiteLLM Proxy test failed. Exception {str(e)}")

# raise openai.BadRequestError
# chat/completions openai
def test_exception_openai_bad_model(client):
    try:
        # Your test data
        test_data = {
            "model": "azure/GPT-12",
            "messages": [
                {
                    "role": "user",
                    "content": "hi"
                },
            ],
            "max_tokens": 10,
        }

        response = client.post("/chat/completions", json=test_data)

        # make an openai client to call _make_status_error_from_response
        openai_client = openai.OpenAI(api_key="anything")
        openai_exception = openai_client._make_status_error_from_response(response=response)
        print("Type of exception=", type(openai_exception))
        assert isinstance(openai_exception, openai.NotFoundError)

    except Exception as e:
        pytest.fail(f"LiteLLM Proxy test failed. Exception {str(e)}")

# chat/completions any model
def test_chat_completion_exception_any_model(client):
    try:
        # Your test data
        test_data = {
            "model": "Lite-GPT-12",
            "messages": [
                {
                    "role": "user",
                    "content": "hi"
                },
            ],
            "max_tokens": 10,
        }

        response = client.post("/chat/completions", json=test_data)

        # make an openai client to call _make_status_error_from_response
        openai_client = openai.OpenAI(api_key="anything")
        openai_exception = openai_client._make_status_error_from_response(response=response)
        print("Exception raised=", openai_exception)
        assert isinstance(openai_exception, openai.NotFoundError)

    except Exception as e:
        pytest.fail(f"LiteLLM Proxy test failed. Exception {str(e)}")

# embeddings any model
def test_embedding_exception_any_model(client):
    try:
        # Your test data
        test_data = {
            "model": "Lite-GPT-12",
            "input": ["hi"]
        }

        response = client.post("/embeddings", json=test_data)
        print("Response from proxy=", response)

        # make an openai client to call _make_status_error_from_response
        openai_client = openai.OpenAI(api_key="anything")
        openai_exception = openai_client._make_status_error_from_response(response=response)
        print("Exception raised=", openai_exception)
        assert isinstance(openai_exception, openai.NotFoundError)

    except Exception as e:
        pytest.fail(f"LiteLLM Proxy test failed. Exception {str(e)}")
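The tests in this new file reach for `_make_status_error_from_response`, a private helper on the OpenAI SDK client, to turn the proxy's HTTP response back into a typed exception. The same mapping can also be observed purely through the public SDK surface; the sketch below is illustrative only and assumes a proxy is already running locally (the URL, port, and deliberately misconfigured upstream key are placeholders).

# Sketch: observing the proxy's exception mapping via the public OpenAI SDK surface.
# Assumes a LiteLLM proxy is running at the placeholder URL with a bad upstream key configured.
import openai
import pytest


def test_bad_upstream_key_surfaces_as_authentication_error():
    client = openai.OpenAI(api_key="anything", base_url="http://localhost:8000")
    with pytest.raises(openai.AuthenticationError):
        client.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "hi"}],
            max_tokens=10,
        )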
@@ -24,30 +24,29 @@ logging.basicConfig(
  from fastapi.testclient import TestClient
  from fastapi import FastAPI
  from litellm.proxy.proxy_server import router, save_worker_config, initialize  # Replace with the actual module where your FastAPI router is defined
- filepath = os.path.dirname(os.path.abspath(__file__))
- config_fp = f"{filepath}/test_configs/test_config_no_auth.yaml"
- save_worker_config(config=config_fp, model=None, alias=None, api_base=None, api_version=None, debug=False, temperature=None, max_tokens=None, request_timeout=600, max_budget=None, telemetry=False, drop_params=True, add_function_to_prompt=False, headers=None, save=False, use_queue=False)
- app = FastAPI()
- app.include_router(router) # Include your router in the test app
- @app.on_event("startup")
- async def wrapper_startup_event():
-     initialize(config=config_fp)

  # Your bearer token
- token = os.getenv("PROXY_MASTER_KEY")
+ token = ""

  headers = {
      "Authorization": f"Bearer {token}"
  }

- # Here you create a fixture that will be used by your tests
- # Make sure the fixture returns TestClient(app)
- @pytest.fixture(autouse=True)
- def client():
-     with TestClient(app) as client:
-         yield client
+ @pytest.fixture(scope="function")
+ def client_no_auth():
+     # Assuming litellm.proxy.proxy_server is an object
+     from litellm.proxy.proxy_server import cleanup_router_config_variables
+     cleanup_router_config_variables()
+     filepath = os.path.dirname(os.path.abspath(__file__))
+     config_fp = f"{filepath}/test_configs/test_config_no_auth.yaml"
+     # initialize can get run in parallel, it sets specific variables for the fast api app, sinc eit gets run in parallel different tests use the wrong variables
+     initialize(config=config_fp)
+     app = FastAPI()
+     app.include_router(router) # Include your router in the test app
+     return TestClient(app)

- def test_chat_completion(client):
+ def test_chat_completion(client_no_auth):
      global headers
      try:
          # Your test data

@@ -62,8 +61,8 @@ def test_chat_completion(client):
              "max_tokens": 10,
          }

-         print("testing proxy server")
-         response = client.post("/v1/chat/completions", json=test_data, headers=headers)
+         print("testing proxy server with chat completions")
+         response = client_no_auth.post("/v1/chat/completions", json=test_data)
          print(f"response - {response.text}")
          assert response.status_code == 200
          result = response.json()

@@ -73,7 +72,8 @@ def test_chat_completion(client):

  # Run the test

- def test_chat_completion_azure(client):
+ def test_chat_completion_azure(client_no_auth):

      global headers
      try:
          # Your test data

@@ -88,8 +88,8 @@ def test_chat_completion_azure(client):
              "max_tokens": 10,
          }

-         print("testing proxy server with Azure Request")
-         response = client.post("/v1/chat/completions", json=test_data, headers=headers)
+         print("testing proxy server with Azure Request /chat/completions")
+         response = client_no_auth.post("/v1/chat/completions", json=test_data)

          assert response.status_code == 200
          result = response.json()

@@ -102,15 +102,55 @@ def test_chat_completion_azure(client):
  # test_chat_completion_azure()

- def test_embedding(client):
+ def test_embedding(client_no_auth):
      global headers
+     from litellm.proxy.proxy_server import user_custom_auth

      try:
          test_data = {
              "model": "azure/azure-embedding-model",
              "input": ["good morning from litellm"],
          }
-         print("testing proxy server with OpenAI embedding")
-         response = client.post("/v1/embeddings", json=test_data, headers=headers)
+         response = client_no_auth.post("/v1/embeddings", json=test_data)

+         assert response.status_code == 200
+         result = response.json()
+         print(len(result["data"][0]["embedding"]))
+         assert len(result["data"][0]["embedding"]) > 10 # this usually has len==1536 so
+     except Exception as e:
+         pytest.fail(f"LiteLLM Proxy test failed. Exception - {str(e)}")
+
+ def test_bedrock_embedding(client_no_auth):
+     global headers
+     from litellm.proxy.proxy_server import user_custom_auth
+
+     try:
+         test_data = {
+             "model": "amazon-embeddings",
+             "input": ["good morning from litellm"],
+         }
+
+         response = client_no_auth.post("/v1/embeddings", json=test_data)
+
+         assert response.status_code == 200
+         result = response.json()
+         print(len(result["data"][0]["embedding"]))
+         assert len(result["data"][0]["embedding"]) > 10 # this usually has len==1536 so
+     except Exception as e:
+         pytest.fail(f"LiteLLM Proxy test failed. Exception - {str(e)}")
+
+ def test_sagemaker_embedding(client_no_auth):
+     global headers
+     from litellm.proxy.proxy_server import user_custom_auth
+
+     try:
+         test_data = {
+             "model": "GPT-J 6B - Sagemaker Text Embedding (Internal)",
+             "input": ["good morning from litellm"],
+         }
+
+         response = client_no_auth.post("/v1/embeddings", json=test_data)
+
          assert response.status_code == 200
          result = response.json()

@@ -122,8 +162,8 @@ def test_embedding(client):
  # Run the test
  # test_embedding()

- @pytest.mark.skip(reason="hitting yaml load issues on circle-ci")
- def test_add_new_model(client):
+ # @pytest.mark.skip(reason="hitting yaml load issues on circle-ci")
+ def test_add_new_model(client_no_auth):
      global headers
      try:
          test_data = {

@@ -135,15 +175,15 @@ def test_add_new_model(client):
                  "description": "this is a test openai model"
              }
          }
-         client.post("/model/new", json=test_data, headers=headers)
-         response = client.get("/model/info", headers=headers)
+         client_no_auth.post("/model/new", json=test_data, headers=headers)
+         response = client_no_auth.get("/model/info", headers=headers)
          assert response.status_code == 200
          result = response.json()
          print(f"response: {result}")
          model_info = None
          for m in result["data"]:
-             if m["id"]["model_name"] == "test_openai_models":
-                 model_info = m["id"]["model_info"]
+             if m["model_name"] == "test_openai_models":
+                 model_info = m["model_info"]
          assert model_info["description"] == "this is a test openai model"
      except Exception as e:
          pytest.fail(f"LiteLLM Proxy test failed. Exception {str(e)}")

@@ -164,10 +204,9 @@ class MyCustomHandler(CustomLogger):
  customHandler = MyCustomHandler()

- def test_chat_completion_optional_params(client):
+ def test_chat_completion_optional_params(client_no_auth):
      # [PROXY: PROD TEST] - DO NOT DELETE
      # This tests if all the /chat/completion params are passed to litellm

      try:
          # Your test data
          litellm.set_verbose=True

@@ -185,7 +224,7 @@ def test_chat_completion_optional_params(client):
          litellm.callbacks = [customHandler]
          print("testing proxy server: optional params")
-         response = client.post("/v1/chat/completions", json=test_data, headers=headers)
+         response = client_no_auth.post("/v1/chat/completions", json=test_data)
          assert response.status_code == 200
          result = response.json()
          print(f"Received response: {result}")

@@ -217,6 +256,29 @@ def test_load_router_config():
          print(result)
          assert len(result[1]) == 2

+         # tests for litellm.cache set from config
+         print("testing reading proxy config for cache")
+         litellm.cache = None
+         load_router_config(
+             router=None,
+             config_file_path=f"{filepath}/example_config_yaml/cache_no_params.yaml"
+         )
+         assert litellm.cache is not None
+         assert "redis_client" in vars(litellm.cache.cache) # it should default to redis on proxy
+         assert litellm.cache.supported_call_types == ['completion', 'acompletion', 'embedding', 'aembedding'] # init with all call types
+
+         print("testing reading proxy config for cache with params")
+         load_router_config(
+             router=None,
+             config_file_path=f"{filepath}/example_config_yaml/cache_with_params.yaml"
+         )
+         assert litellm.cache is not None
+         print(litellm.cache)
+         print(litellm.cache.supported_call_types)
+         print(vars(litellm.cache.cache))
+         assert "redis_client" in vars(litellm.cache.cache) # it should default to redis on proxy
+         assert litellm.cache.supported_call_types == ['embedding', 'aembedding'] # init with all call types
+
      except Exception as e:
          pytest.fail("Proxy: Got exception reading config", e)
  # test_load_router_config()
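The new cache assertions above check `litellm.cache` after the proxy config is loaded. A compact way to see the same attributes outside the proxy is to construct the cache by hand; the sketch below is illustrative, assumes a local Redis instance, and uses placeholder connection values rather than the proxy's own code path.

# Sketch: inspecting the attributes the test above asserts on, by building the cache directly.
# Assumes Redis is reachable at the placeholder host/port.
import litellm
from litellm.caching import Cache

litellm.cache = Cache(type="redis", host="localhost", port="6379", password="")

print(type(litellm.cache.cache))            # the underlying Redis-backed cache object (holds a redis_client)
print(litellm.cache.supported_call_types)   # by default covers completion/acompletion/embedding/aembedding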
@@ -37,6 +37,8 @@ async def wrapper_startup_event():
  # Make sure the fixture returns TestClient(app)
  @pytest.fixture(autouse=True)
  def client():
+     from litellm.proxy.proxy_server import cleanup_router_config_variables
+     cleanup_router_config_variables()
      with TestClient(app) as client:
          yield client

@@ -69,6 +71,38 @@ def test_add_new_key(client):
      except Exception as e:
          pytest.fail(f"LiteLLM Proxy test failed. Exception: {str(e)}")

+ def test_update_new_key(client):
+     try:
+         # Your test data
+         test_data = {
+             "models": ["gpt-3.5-turbo", "gpt-4", "claude-2", "azure-model"],
+             "aliases": {"mistral-7b": "gpt-3.5-turbo"},
+             "duration": "20m"
+         }
+         print("testing proxy server")
+         # Your bearer token
+         token = os.getenv("PROXY_MASTER_KEY")
+
+         headers = {
+             "Authorization": f"Bearer {token}"
+         }
+         response = client.post("/key/generate", json=test_data, headers=headers)
+         print(f"response: {response.text}")
+         assert response.status_code == 200
+         result = response.json()
+         assert result["key"].startswith("sk-")
+         def _post_data():
+             json_data = {'models': ['bedrock-models'], "key": result["key"]}
+             response = client.post("/key/update", json=json_data, headers=headers)
+             print(f"response text: {response.text}")
+             assert response.status_code == 200
+             return response
+         _post_data()
+         print(f"Received response: {result}")
+     except Exception as e:
+         pytest.fail(f"LiteLLM Proxy test failed. Exception: {str(e)}")

  # # Run the test - only runs via pytest
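For reference, the /key/generate then /key/update flow exercised by test_update_new_key can also be driven against a running proxy with plain HTTP calls. The sketch below is illustrative; the base URL and master key are placeholders.

# Sketch: generate a key, then update it, against a running proxy (URL and master key are placeholders).
import requests

BASE_URL = "http://localhost:8000"
headers = {"Authorization": "Bearer sk-1234"}  # the proxy master key

# 1. generate a key scoped to a few models
gen = requests.post(
    f"{BASE_URL}/key/generate",
    json={"models": ["gpt-3.5-turbo", "azure-model"], "duration": "20m"},
    headers=headers,
)
gen.raise_for_status()
new_key = gen.json()["key"]  # proxy keys start with "sk-"

# 2. point the generated key at a different model list
upd = requests.post(
    f"{BASE_URL}/key/update",
    json={"key": new_key, "models": ["bedrock-models"]},
    headers=headers,
)
upd.raise_for_status()
print(upd.json())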
@@ -366,69 +366,12 @@ def test_function_calling():
          }
      ]

-     router = Router(model_list=model_list, routing_strategy="latency-based-routing")
+     router = Router(model_list=model_list)
      response = router.completion(model="gpt-3.5-turbo-0613", messages=messages, functions=functions)
      router.reset()
      print(response)

- def test_acompletion_on_router():
-     # tests acompletion + caching on router
-     try:
-         litellm.set_verbose = True
-         model_list = [
-             {
-                 "model_name": "gpt-3.5-turbo",
-                 "litellm_params": {
-                     "model": "gpt-3.5-turbo-0613",
-                     "api_key": os.getenv("OPENAI_API_KEY"),
-                 },
-                 "tpm": 100000,
-                 "rpm": 10000,
-             },
-             {
-                 "model_name": "gpt-3.5-turbo",
-                 "litellm_params": {
-                     "model": "azure/chatgpt-v-2",
-                     "api_key": os.getenv("AZURE_API_KEY"),
-                     "api_base": os.getenv("AZURE_API_BASE"),
-                     "api_version": os.getenv("AZURE_API_VERSION")
-                 },
-                 "tpm": 100000,
-                 "rpm": 10000,
-             }
-         ]
-
-         messages = [
-             {"role": "user", "content": f"write a one sentence poem {time.time()}?"}
-         ]
-         start_time = time.time()
-         router = Router(model_list=model_list,
-                         redis_host=os.environ["REDIS_HOST"],
-                         redis_password=os.environ["REDIS_PASSWORD"],
-                         redis_port=os.environ["REDIS_PORT"],
-                         cache_responses=True,
-                         timeout=30,
-                         routing_strategy="simple-shuffle")
-         async def get_response():
-             print("Testing acompletion + caching on router")
-             response1 = await router.acompletion(model="gpt-3.5-turbo", messages=messages, temperature=1)
-             print(f"response1: {response1}")
-             response2 = await router.acompletion(model="gpt-3.5-turbo", messages=messages, temperature=1)
-             print(f"response2: {response2}")
-             assert response1.id == response2.id
-             assert len(response1.choices[0].message.content) > 0
-             assert response1.choices[0].message.content == response2.choices[0].message.content
-         asyncio.run(get_response())
-         router.reset()
-     except litellm.Timeout as e:
-         end_time = time.time()
-         print(f"timeout error occurred: {end_time - start_time}")
-         pass
-     except Exception as e:
-         traceback.print_exc()
-         pytest.fail(f"Error occurred: {e}")
-
- test_acompletion_on_router()
+ # test_acompletion_on_router()

  def test_function_calling_on_router():
      try:

@@ -507,7 +450,6 @@ def test_aembedding_on_router():
          model="text-embedding-ada-002",
          input=["good morning from litellm 2"],
      )
-     print("sync embedding response: ", response)
      router.reset()
  except Exception as e:
      traceback.print_exc()

@@ -591,6 +533,30 @@ def test_bedrock_on_router():
          pytest.fail(f"Error occurred: {e}")
  # test_bedrock_on_router()

+ # test openai-compatible endpoint
+ @pytest.mark.asyncio
+ async def test_mistral_on_router():
+     litellm.set_verbose = True
+     model_list = [
+         {
+             "model_name": "gpt-3.5-turbo",
+             "litellm_params": {
+                 "model": "mistral/mistral-medium",
+             },
+         },
+     ]
+     router = Router(model_list=model_list)
+     response = await router.acompletion(
+         model="gpt-3.5-turbo",
+         messages=[
+             {
+                 "role": "user",
+                 "content": "hello from litellm test",
+             }
+         ]
+     )
+     print(response)
+ asyncio.run(test_mistral_on_router())

  def test_openai_completion_on_router():
      # [PROD Use Case] - Makes an acompletion call + async acompletion call, and sync acompletion call, sync completion + stream
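The test_mistral_on_router addition above routes an OpenAI-compatible provider behind a familiar alias. A standalone sketch of that pattern follows; the alias and provider name come from the test, while the env-var name is a placeholder assumption.

# Sketch: routing an openai-compatible provider behind a "gpt-3.5-turbo" alias, as the new test does.
import asyncio
import os
from litellm import Router

model_list = [
    {
        "model_name": "gpt-3.5-turbo",                  # the alias callers use
        "litellm_params": {
            "model": "mistral/mistral-medium",           # the actual deployment behind it
            "api_key": os.getenv("MISTRAL_API_KEY"),     # placeholder env var
        },
    },
]

router = Router(model_list=model_list)  # default routing strategy; e.g. "simple-shuffle" can be passed explicitly

async def main():
    response = await router.acompletion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "hello from litellm test"}],
    )
    print(response.choices[0].message.content)

asyncio.run(main())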
127  litellm/tests/test_router_caching.py  (new file)
@@ -0,0 +1,127 @@
#### What this tests ####
# This tests caching on the router
import sys, os, time
import traceback, asyncio
import pytest
sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
import litellm
from litellm import Router

## Scenarios
## 1. 2 models - openai + azure - 1 model group "gpt-3.5-turbo",
## 2. 2 models - openai, azure - 2 diff model groups, 1 caching group

@pytest.mark.asyncio
async def test_acompletion_caching_on_router():
    # tests acompletion + caching on router
    try:
        litellm.set_verbose = True
        model_list = [
            {
                "model_name": "gpt-3.5-turbo",
                "litellm_params": {
                    "model": "gpt-3.5-turbo-0613",
                    "api_key": os.getenv("OPENAI_API_KEY"),
                },
                "tpm": 100000,
                "rpm": 10000,
            },
            {
                "model_name": "gpt-3.5-turbo",
                "litellm_params": {
                    "model": "azure/chatgpt-v-2",
                    "api_key": os.getenv("AZURE_API_KEY"),
                    "api_base": os.getenv("AZURE_API_BASE"),
                    "api_version": os.getenv("AZURE_API_VERSION")
                },
                "tpm": 100000,
                "rpm": 10000,
            }
        ]

        messages = [
            {"role": "user", "content": f"write a one sentence poem {time.time()}?"}
        ]
        start_time = time.time()
        router = Router(model_list=model_list,
                        redis_host=os.environ["REDIS_HOST"],
                        redis_password=os.environ["REDIS_PASSWORD"],
                        redis_port=os.environ["REDIS_PORT"],
                        cache_responses=True,
                        timeout=30,
                        routing_strategy="simple-shuffle")
        response1 = await router.acompletion(model="gpt-3.5-turbo", messages=messages, temperature=1)
        print(f"response1: {response1}")
        await asyncio.sleep(1) # add cache is async, async sleep for cache to get set
        response2 = await router.acompletion(model="gpt-3.5-turbo", messages=messages, temperature=1)
        print(f"response2: {response2}")
        assert response1.id == response2.id
        assert len(response1.choices[0].message.content) > 0
        assert response1.choices[0].message.content == response2.choices[0].message.content
        router.reset()
    except litellm.Timeout as e:
        end_time = time.time()
        print(f"timeout error occurred: {end_time - start_time}")
        pass
    except Exception as e:
        traceback.print_exc()
        pytest.fail(f"Error occurred: {e}")

@pytest.mark.asyncio
async def test_acompletion_caching_on_router_caching_groups():
    # tests acompletion + caching on router
    try:
        litellm.set_verbose = True
        model_list = [
            {
                "model_name": "openai-gpt-3.5-turbo",
                "litellm_params": {
                    "model": "gpt-3.5-turbo-0613",
                    "api_key": os.getenv("OPENAI_API_KEY"),
                },
                "tpm": 100000,
                "rpm": 10000,
            },
            {
                "model_name": "azure-gpt-3.5-turbo",
                "litellm_params": {
                    "model": "azure/chatgpt-v-2",
                    "api_key": os.getenv("AZURE_API_KEY"),
                    "api_base": os.getenv("AZURE_API_BASE"),
                    "api_version": os.getenv("AZURE_API_VERSION")
                },
                "tpm": 100000,
                "rpm": 10000,
            }
        ]

        messages = [
            {"role": "user", "content": f"write a one sentence poem {time.time()}?"}
        ]
        start_time = time.time()
        router = Router(model_list=model_list,
                        redis_host=os.environ["REDIS_HOST"],
                        redis_password=os.environ["REDIS_PASSWORD"],
                        redis_port=os.environ["REDIS_PORT"],
                        cache_responses=True,
                        timeout=30,
                        routing_strategy="simple-shuffle",
                        caching_groups=[("openai-gpt-3.5-turbo", "azure-gpt-3.5-turbo")])
        response1 = await router.acompletion(model="openai-gpt-3.5-turbo", messages=messages, temperature=1)
        print(f"response1: {response1}")
        await asyncio.sleep(1) # add cache is async, async sleep for cache to get set
        response2 = await router.acompletion(model="azure-gpt-3.5-turbo", messages=messages, temperature=1)
        print(f"response2: {response2}")
        assert response1.id == response2.id
        assert len(response1.choices[0].message.content) > 0
        assert response1.choices[0].message.content == response2.choices[0].message.content
        router.reset()
    except litellm.Timeout as e:
        end_time = time.time()
        print(f"timeout error occurred: {end_time - start_time}")
        pass
    except Exception as e:
        traceback.print_exc()
        pytest.fail(f"Error occurred: {e}")
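The second test in this new file leans on caching_groups so that two differently named model groups share cache entries. A compact sketch of just that wiring follows; the deployment names come from the test, while the keys and Redis settings are placeholders.

# Sketch: two model groups sharing one response cache via caching_groups (connection values are placeholders).
import os
from litellm import Router

router = Router(
    model_list=[
        {"model_name": "openai-gpt-3.5-turbo",
         "litellm_params": {"model": "gpt-3.5-turbo-0613",
                            "api_key": os.getenv("OPENAI_API_KEY")}},
        {"model_name": "azure-gpt-3.5-turbo",
         "litellm_params": {"model": "azure/chatgpt-v-2",
                            "api_key": os.getenv("AZURE_API_KEY"),
                            "api_base": os.getenv("AZURE_API_BASE"),
                            "api_version": os.getenv("AZURE_API_VERSION")}},
    ],
    redis_host=os.environ.get("REDIS_HOST", "localhost"),
    redis_password=os.environ.get("REDIS_PASSWORD", ""),
    redis_port=os.environ.get("REDIS_PORT", "6379"),
    cache_responses=True,
    caching_groups=[("openai-gpt-3.5-turbo", "azure-gpt-3.5-turbo")],
)
# With the caching group in place, an identical prompt sent to either group name can be
# answered from the shared Redis cache, which is what the test asserts via matching response ids.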
@@ -21,80 +21,89 @@ class MyCustomHandler(CustomLogger):
          print(f"Pre-API Call")

      def log_post_api_call(self, kwargs, response_obj, start_time, end_time):
-         print(f"Post-API Call")
+         print(f"Post-API Call - response object: {response_obj}; model: {kwargs['model']}")

      def log_stream_event(self, kwargs, response_obj, start_time, end_time):
          print(f"On Stream")

+     def async_log_stream_event(self, kwargs, response_obj, start_time, end_time):
+         print(f"On Stream")
+
      def log_success_event(self, kwargs, response_obj, start_time, end_time):
          print(f"previous_models: {kwargs['litellm_params']['metadata']['previous_models']}")
          self.previous_models += len(kwargs["litellm_params"]["metadata"]["previous_models"]) # {"previous_models": [{"model": litellm_model_name, "exception_type": AuthenticationError, "exception_string": <complete_traceback>}]}
          print(f"self.previous_models: {self.previous_models}")
          print(f"On Success")

+     async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
+         print(f"previous_models: {kwargs['litellm_params']['metadata']['previous_models']}")
+         self.previous_models += len(kwargs["litellm_params"]["metadata"]["previous_models"]) # {"previous_models": [{"model": litellm_model_name, "exception_type": AuthenticationError, "exception_string": <complete_traceback>}]}
+         print(f"self.previous_models: {self.previous_models}")
+         print(f"On Success")
+
      def log_failure_event(self, kwargs, response_obj, start_time, end_time):
          print(f"On Failure")

- model_list = [
-     { # list of model deployments
-         "model_name": "azure/gpt-3.5-turbo", # openai model name
-         "litellm_params": { # params for litellm completion/embedding call
-             "model": "azure/chatgpt-v-2",
-             "api_key": "bad-key",
-             "api_version": os.getenv("AZURE_API_VERSION"),
-             "api_base": os.getenv("AZURE_API_BASE")
-         },
-         "tpm": 240000,
-         "rpm": 1800
-     },
-     { # list of model deployments
-         "model_name": "azure/gpt-3.5-turbo-context-fallback", # openai model name
-         "litellm_params": { # params for litellm completion/embedding call
-             "model": "azure/chatgpt-v-2",
-             "api_key": os.getenv("AZURE_API_KEY"),
-             "api_version": os.getenv("AZURE_API_VERSION"),
-             "api_base": os.getenv("AZURE_API_BASE")
-         },
-         "tpm": 240000,
-         "rpm": 1800
-     },
-     {
-         "model_name": "azure/gpt-3.5-turbo", # openai model name
-         "litellm_params": { # params for litellm completion/embedding call
-             "model": "azure/chatgpt-functioncalling",
-             "api_key": "bad-key",
-             "api_version": os.getenv("AZURE_API_VERSION"),
-             "api_base": os.getenv("AZURE_API_BASE")
-         },
-         "tpm": 240000,
-         "rpm": 1800
-     },
-     {
-         "model_name": "gpt-3.5-turbo", # openai model name
-         "litellm_params": { # params for litellm completion/embedding call
-             "model": "gpt-3.5-turbo",
-             "api_key": os.getenv("OPENAI_API_KEY"),
-         },
-         "tpm": 1000000,
-         "rpm": 9000
-     },
-     {
-         "model_name": "gpt-3.5-turbo-16k", # openai model name
-         "litellm_params": { # params for litellm completion/embedding call
-             "model": "gpt-3.5-turbo-16k",
-             "api_key": os.getenv("OPENAI_API_KEY"),
-         },
-         "tpm": 1000000,
-         "rpm": 9000
-     }
- ]

  kwargs = {"model": "azure/gpt-3.5-turbo", "messages": [{"role": "user", "content":"Hey, how's it going?"}]}

  def test_sync_fallbacks():
      try:
+         model_list = [ ... ]  # the same five deployments previously defined at module level (shown above), moved verbatim into this test
          litellm.set_verbose = True
          customHandler = MyCustomHandler()
          litellm.callbacks = [customHandler]

@@ -106,62 +115,93 @@ def test_sync_fallbacks():
          print(f"response: {response}")
          time.sleep(0.05) # allow a delay as success_callbacks are on a separate thread
          assert customHandler.previous_models == 1 # 0 retries, 1 fallback
+         print("Passed ! Test router_fallbacks: test_sync_fallbacks()")
          router.reset()
      except Exception as e:
          print(e)
  # test_sync_fallbacks()

- def test_async_fallbacks():
+ @pytest.mark.asyncio
+ async def test_async_fallbacks():
      litellm.set_verbose = False
+     model_list = [ ... ]  # same five deployments as above, defined inside the test
      router = Router(model_list=model_list,
                      fallbacks=[{"azure/gpt-3.5-turbo": ["gpt-3.5-turbo"]}],
                      context_window_fallbacks=[{"azure/gpt-3.5-turbo-context-fallback": ["gpt-3.5-turbo-16k"]}, {"gpt-3.5-turbo": ["gpt-3.5-turbo-16k"]}],
                      set_verbose=False)
-     async def test_get_response():
-         customHandler = MyCustomHandler()
-         litellm.callbacks = [customHandler]
-         user_message = "Hello, how are you?"
-         messages = [{"content": user_message, "role": "user"}]
-         try:
-             response = await router.acompletion(**kwargs)
-             print(f"customHandler.previous_models: {customHandler.previous_models}")
-             time.sleep(0.05) # allow a delay as success_callbacks are on a separate thread
-             assert customHandler.previous_models == 1 # 0 retries, 1 fallback
-             router.reset()
-         except litellm.Timeout as e:
-             pass
-         except Exception as e:
-             pytest.fail(f"An exception occurred: {e}")
-         finally:
-             router.reset()
-     asyncio.run(test_get_response())
+     customHandler = MyCustomHandler()
+     litellm.callbacks = [customHandler]
+     user_message = "Hello, how are you?"
+     messages = [{"content": user_message, "role": "user"}]
+     try:
+         response = await router.acompletion(**kwargs)
+         print(f"customHandler.previous_models: {customHandler.previous_models}")
+         await asyncio.sleep(0.05) # allow a delay as success_callbacks are on a separate thread
+         assert customHandler.previous_models == 1 # 0 retries, 1 fallback
+         router.reset()
+     except litellm.Timeout as e:
+         pass
+     except Exception as e:
+         pytest.fail(f"An exception occurred: {e}")
+     finally:
+         router.reset()
  # test_async_fallbacks()

- ## COMMENTING OUT as the context size exceeds both gpt-3.5-turbo and gpt-3.5-turbo-16k, need a better message here
- # def test_sync_context_window_fallbacks():
- #     try:
- #         customHandler = MyCustomHandler()
- #         litellm.callbacks = [customHandler]
- #         sample_text = "Say error 50 times" * 10000
- #         kwargs["model"] = "azure/gpt-3.5-turbo-context-fallback"
- #         kwargs["messages"] = [{"role": "user", "content": sample_text}]
- #         router = Router(model_list=model_list,
- #                         fallbacks=[{"azure/gpt-3.5-turbo": ["gpt-3.5-turbo"]}],
- #                         context_window_fallbacks=[{"azure/gpt-3.5-turbo-context-fallback": ["gpt-3.5-turbo-16k"]}, {"gpt-3.5-turbo": ["gpt-3.5-turbo-16k"]}],
- #                         set_verbose=False)
- #         response = router.completion(**kwargs)
- #         print(f"response: {response}")
- #         time.sleep(0.05) # allow a delay as success_callbacks are on a separate thread
- #         assert customHandler.previous_models == 1 # 0 retries, 1 fallback
- #         router.reset()
- #     except Exception as e:
- #         print(f"An exception occurred - {e}")
- #     finally:
- #         router.reset()
- # test_sync_context_window_fallbacks()

  def test_dynamic_fallbacks_sync():
      """
      Allow setting the fallback in the router.completion() call.

@@ -169,6 +209,60 @@ def test_dynamic_fallbacks_sync():
      try:
          customHandler = MyCustomHandler()
          litellm.callbacks = [customHandler]
+         model_list = [ ... ]  # same five deployments as above, defined inside the test
          router = Router(model_list=model_list, set_verbose=True)
          kwargs = {}
          kwargs["model"] = "azure/gpt-3.5-turbo"

@@ -184,26 +278,83 @@ def test_dynamic_fallbacks_sync():

  # test_dynamic_fallbacks_sync()

- def test_dynamic_fallbacks_async():
+ @pytest.mark.asyncio
+ async def test_dynamic_fallbacks_async():
      """
      Allow setting the fallback in the router.completion() call.
      """
-     async def test_get_response():
-         try:
-             customHandler = MyCustomHandler()
-             litellm.callbacks = [customHandler]
-             router = Router(model_list=model_list, set_verbose=True)
-             kwargs = {}
-             kwargs["model"] = "azure/gpt-3.5-turbo"
-             kwargs["messages"] = [{"role": "user", "content": "Hey, how's it going?"}]
-             kwargs["fallbacks"] = [{"azure/gpt-3.5-turbo": ["gpt-3.5-turbo"]}]
-             response = await router.acompletion(**kwargs)
-             print(f"response: {response}")
-             time.sleep(0.05) # allow a delay as success_callbacks are on a separate thread
-             assert customHandler.previous_models == 1 # 0 retries, 1 fallback
-             router.reset()
-         except Exception as e:
-             pytest.fail(f"An exception occurred - {e}")
-     asyncio.run(test_get_response())
- # test_dynamic_fallbacks_async()
+     try:
+         model_list = [ ... ]  # same five deployments as above, defined inside the test
+
+         print()
+         print()
+         print()
+         print()
+         print(f"STARTING DYNAMIC ASYNC")
+         customHandler = MyCustomHandler()
+         litellm.callbacks = [customHandler]
+         router = Router(model_list=model_list, set_verbose=True)
+         kwargs = {}
+         kwargs["model"] = "azure/gpt-3.5-turbo"
+         kwargs["messages"] = [{"role": "user", "content": "Hey, how's it going?"}]
+         kwargs["fallbacks"] = [{"azure/gpt-3.5-turbo": ["gpt-3.5-turbo"]}]
+         response = await router.acompletion(**kwargs)
+         print(f"RESPONSE: {response}")
+         await asyncio.sleep(0.05) # allow a delay as success_callbacks are on a separate thread
+         assert customHandler.previous_models == 1 # 0 retries, 1 fallback
+         router.reset()
+     except Exception as e:
+         pytest.fail(f"An exception occurred - {e}")
+ # asyncio.run(test_dynamic_fallbacks_async())
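To summarize the wiring these fallback tests repeat, here is a self-contained sketch of a router with static fallbacks plus a per-call override; the model list is trimmed to two deployments and the keys are placeholders, so this is an illustration of the pattern rather than the tests' exact setup (context_window_fallbacks follows the same mapping shape).

# Sketch: static fallbacks plus a per-call fallback override (trimmed model list; keys are placeholders).
import os
from litellm import Router

model_list = [
    {   # deliberately broken deployment, so the router has something to fall back from
        "model_name": "azure/gpt-3.5-turbo",
        "litellm_params": {
            "model": "azure/chatgpt-v-2",
            "api_key": "bad-key",
            "api_version": os.getenv("AZURE_API_VERSION"),
            "api_base": os.getenv("AZURE_API_BASE"),
        },
    },
    {   # healthy deployment used as the fallback target
        "model_name": "gpt-3.5-turbo",
        "litellm_params": {
            "model": "gpt-3.5-turbo",
            "api_key": os.getenv("OPENAI_API_KEY"),
        },
    },
]

router = Router(
    model_list=model_list,
    fallbacks=[{"azure/gpt-3.5-turbo": ["gpt-3.5-turbo"]}],
)

# the fallback mapping can also be supplied per call, as test_dynamic_fallbacks_sync does
response = router.completion(
    model="azure/gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
    fallbacks=[{"azure/gpt-3.5-turbo": ["gpt-3.5-turbo"]}],
)
print(response)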
Some files were not shown because too many files have changed in this diff.