Commit 47ba8082df: Merge branch 'main' into public-fix-1
109 changed files with 8257 additions and 3200 deletions
@@ -61,7 +61,7 @@ jobs:
          command: |
            pwd
            ls
            python -m pytest -vv litellm/tests/ -x --junitxml=test-results/junit.xml --durations=5
            python -m pytest -vv litellm/tests/ -x --junitxml=test-results/junit.xml --durations=5
          no_output_timeout: 120m

      # Store test results

@@ -78,6 +78,11 @@ jobs:
    steps:
      - checkout

      - run:
          name: Copy model_prices_and_context_window File to model_prices_and_context_window_backup
          command: |
            cp model_prices_and_context_window.json litellm/model_prices_and_context_window_backup.json

      - run:
          name: Check if litellm dir was updated or if pyproject.toml was modified
.gitignore (vendored, 6 changes)

@@ -19,3 +19,9 @@ litellm/proxy/_secret_config.yaml
litellm/tests/aiologs.log
litellm/tests/exception_data.txt
litellm/tests/config_*.yaml
litellm/tests/langfuse.log
litellm/tests/test_custom_logger.py
litellm/tests/langfuse.log
litellm/tests/dynamo*.log
.vscode/settings.json
litellm/proxy/log.txt
Dockerfile (31 changes)

@@ -1,8 +1,11 @@
# Base image
ARG LITELLM_BASE_IMAGE=python:3.9-slim
ARG LITELLM_BUILD_IMAGE=python:3.9

# allow users to specify, else use python 3.9-slim
FROM $LITELLM_BASE_IMAGE
# Runtime image
ARG LITELLM_RUNTIME_IMAGE=python:3.9-slim

# allow users to specify, else use python 3.9
FROM $LITELLM_BUILD_IMAGE as builder

# Set the working directory to /app
WORKDIR /app
@@ -16,7 +19,7 @@ RUN pip install --upgrade pip && \
    pip install build

# Copy the current directory contents into the container at /app
COPY . /app
COPY requirements.txt .

# Build the package
RUN rm -rf dist/* && python -m build
@@ -25,13 +28,27 @@ RUN rm -rf dist/* && python -m build
RUN pip install dist/*.whl

# Install any needed packages specified in requirements.txt
RUN pip wheel --no-cache-dir --wheel-dir=wheels -r requirements.txt
RUN pip install --no-cache-dir --find-links=wheels -r requirements.txt
RUN pip install wheel && \
    pip wheel --no-cache-dir --wheel-dir=/app/wheels -r requirements.txt

###############################################################################
FROM $LITELLM_RUNTIME_IMAGE as runtime

WORKDIR /app

# Copy the current directory contents into the container at /app
COPY . .

COPY --from=builder /app/wheels /app/wheels

RUN pip install --no-index --find-links=/app/wheels -r requirements.txt

# Trigger the Prisma CLI to be installed
RUN prisma -v

EXPOSE 4000/tcp

# Start the litellm proxy, using the `litellm` cli command https://docs.litellm.ai/docs/simple_proxy

# Start the litellm proxy with default options
CMD ["--port", "4000"]
README.md (17 changes)

@@ -62,6 +62,22 @@ response = completion(model="command-nightly", messages=messages)
print(response)
```

## Async ([Docs](https://docs.litellm.ai/docs/completion/stream#async-completion))

```python
from litellm import acompletion
import asyncio

async def test_get_response():
    user_message = "Hello, how are you?"
    messages = [{"content": user_message, "role": "user"}]
    response = await acompletion(model="gpt-3.5-turbo", messages=messages)
    return response

response = asyncio.run(test_get_response())
print(response)
```

## Streaming ([Docs](https://docs.litellm.ai/docs/completion/stream))
liteLLM supports streaming the model response back, pass `stream=True` to get a streaming iterator in response.
Streaming is supported for all models (Bedrock, Huggingface, TogetherAI, Azure, OpenAI, etc.)
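A minimal streaming sketch, assuming the same `completion` import as the examples above; the chunks follow the OpenAI delta format, so the final chunk's `content` may be `None`:

```python
from litellm import completion

messages = [{"content": "Hello, how are you?", "role": "user"}]

# stream=True returns an iterator of chunks instead of a single response
response = completion(model="gpt-3.5-turbo", messages=messages, stream=True)
for chunk in response:
    # delta.content is None on the final chunk, so fall back to ""
    print(chunk.choices[0].delta.content or "", end="")
```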
@@ -140,6 +156,7 @@ response = completion(model="gpt-3.5-turbo", messages=[{"role": "user", "content
| [openrouter](https://docs.litellm.ai/docs/providers/openrouter) | ✅ | ✅ | ✅ | ✅ |
| [google - vertex_ai](https://docs.litellm.ai/docs/providers/vertex) | ✅ | ✅ | ✅ | ✅ |
| [google - palm](https://docs.litellm.ai/docs/providers/palm) | ✅ | ✅ | ✅ | ✅ |
| [mistral ai api](https://docs.litellm.ai/docs/providers/mistral) | ✅ | ✅ | ✅ | ✅ |
| [ai21](https://docs.litellm.ai/docs/providers/ai21) | ✅ | ✅ | ✅ | ✅ |
| [baseten](https://docs.litellm.ai/docs/providers/baseten) | ✅ | ✅ | ✅ | ✅ |
| [vllm](https://docs.litellm.ai/docs/providers/vllm) | ✅ | ✅ | ✅ | ✅ |
Binary files (vendored, new files; binary content not shown):
BIN dist/litellm-1.12.5.dev1-py3-none-any.whl
BIN dist/litellm-1.12.5.dev1.tar.gz
BIN dist/litellm-1.12.6.dev1-py3-none-any.whl
BIN dist/litellm-1.12.6.dev1.tar.gz
BIN dist/litellm-1.12.6.dev2-py3-none-any.whl
BIN dist/litellm-1.12.6.dev2.tar.gz
BIN dist/litellm-1.12.6.dev3-py3-none-any.whl
BIN dist/litellm-1.12.6.dev3.tar.gz
BIN dist/litellm-1.12.6.dev4-py3-none-any.whl
BIN dist/litellm-1.12.6.dev4.tar.gz
BIN dist/litellm-1.12.6.dev5-py3-none-any.whl
BIN dist/litellm-1.12.6.dev5.tar.gz
BIN dist/litellm-1.14.0.dev1-py3-none-any.whl
BIN dist/litellm-1.14.0.dev1.tar.gz
BIN dist/litellm-1.14.5.dev1-py3-none-any.whl
BIN dist/litellm-1.14.5.dev1.tar.gz
@ -55,27 +55,76 @@ litellm.cache = cache # set litellm.cache to your cache
|
|||
|
||||
```
|
||||
|
||||
### Detecting Cached Responses
|
||||
For responses that were returned as a cache hit, the response includes the param `cache` = True
|
||||
## Cache Initialization Parameters
|
||||
|
||||
:::info
|
||||
#### `type` (str, optional)
|
||||
|
||||
Only valid for OpenAI <= 0.28.1 [Let us know if you still need this](https://github.com/BerriAI/litellm/issues/new?assignees=&labels=bug&projects=&template=bug_report.yml&title=%5BBug%5D%3A+)
|
||||
:::
|
||||
The type of cache to initialize. It can be either "local" or "redis". Defaults to "local".
|
||||
|
||||
Example response with cache hit
|
||||
```python
|
||||
{
|
||||
'cache': True,
|
||||
'id': 'chatcmpl-7wggdzd6OXhgE2YhcLJHJNZsEWzZ2',
|
||||
'created': 1694221467,
|
||||
'model': 'gpt-3.5-turbo-0613',
|
||||
'choices': [
|
||||
{
|
||||
'index': 0, 'message': {'role': 'assistant', 'content': 'I\'m sorry, but I couldn\'t find any information about "litellm" or how many stars it has. It is possible that you may be referring to a specific product, service, or platform that I am not familiar with. Can you please provide more context or clarify your question?'
|
||||
}, 'finish_reason': 'stop'}
|
||||
],
|
||||
'usage': {'prompt_tokens': 17, 'completion_tokens': 59, 'total_tokens': 76},
|
||||
}
|
||||
#### `host` (str, optional)
|
||||
|
||||
```
|
||||
The host address for the Redis cache. This parameter is required if the `type` is set to "redis".
|
||||
|
||||
#### `port` (int, optional)
|
||||
|
||||
The port number for the Redis cache. This parameter is required if the `type` is set to "redis".
|
||||
|
||||
#### `password` (str, optional)
|
||||
|
||||
The password for the Redis cache. This parameter is required if the `type` is set to "redis".
|
||||
|
||||
#### `supported_call_types` (list, optional)
|
||||
|
||||
A list of call types to cache for. Defaults to caching for all call types. The available call types are:
|
||||
|
||||
- "completion"
|
||||
- "acompletion"
|
||||
- "embedding"
|
||||
- "aembedding"
|
||||
|
||||
#### `**kwargs` (additional keyword arguments)
|
||||
|
||||
Additional keyword arguments are accepted for the initialization of the Redis cache using the `redis.Redis()` constructor. These arguments allow you to fine-tune the Redis cache configuration based on your specific needs.
|
||||
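A minimal initialization sketch pulling the parameters above together; it assumes the pass-through behaviour described for `**kwargs`, and the extra Redis arguments (`ssl`, `socket_timeout`) are only illustrative:

```python
import os
import litellm
from litellm import Cache

litellm.cache = Cache(
    type="redis",                                    # "local" or "redis"
    host=os.environ.get("REDIS_HOST", "localhost"),
    port=os.environ.get("REDIS_PORT", "6379"),
    password=os.environ.get("REDIS_PASSWORD"),
    supported_call_types=["completion", "acompletion"],  # only cache these call types
    ssl=True,            # extra kwargs like these are forwarded to redis.Redis()
    socket_timeout=5,
)
```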
|
||||
|
||||
## Logging
|
||||
|
||||
Cache hits are logged in success events as `kwarg["cache_hit"]`.
|
||||
|
||||
Here's an example of accessing it:
|
||||
|
||||
```python
|
||||
import litellm
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from litellm import completion, acompletion, Cache
|
||||
|
||||
# create custom callback for success_events
|
||||
class MyCustomHandler(CustomLogger):
|
||||
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||
print(f"On Success")
|
||||
print(f"Value of Cache hit: {kwargs['cache_hit']"})
|
||||
|
||||
async def test_async_completion_azure_caching():
|
||||
# set custom callback
|
||||
customHandler_caching = MyCustomHandler()
|
||||
litellm.callbacks = [customHandler_caching]
|
||||
|
||||
# init cache
|
||||
litellm.cache = Cache(type="redis", host=os.environ['REDIS_HOST'], port=os.environ['REDIS_PORT'], password=os.environ['REDIS_PASSWORD'])
|
||||
unique_time = time.time()
|
||||
response1 = await litellm.acompletion(model="azure/chatgpt-v-2",
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": f"Hi 👋 - i'm async azure {unique_time}"
|
||||
}],
|
||||
caching=True)
|
||||
await asyncio.sleep(1)
|
||||
print(f"customHandler_caching.states pre-cache hit: {customHandler_caching.states}")
|
||||
response2 = await litellm.acompletion(model="azure/chatgpt-v-2",
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": f"Hi 👋 - i'm async azure {unique_time}"
|
||||
}],
|
||||
caching=True)
|
||||
await asyncio.sleep(1) # success callbacks are done in parallel
|
||||
```
|
||||
|
|
|
@ -4,7 +4,9 @@
|
|||
You can create a custom callback class to precisely log events as they occur in litellm.
|
||||
|
||||
```python
|
||||
import litellm
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from litellm import completion, acompletion
|
||||
|
||||
class MyCustomHandler(CustomLogger):
|
||||
def log_pre_api_call(self, model, messages, kwargs):
|
||||
|
@ -21,14 +23,38 @@ class MyCustomHandler(CustomLogger):
|
|||
|
||||
def log_failure_event(self, kwargs, response_obj, start_time, end_time):
|
||||
print(f"On Failure")
|
||||
|
||||
#### ASYNC #### - for acompletion/aembeddings
|
||||
|
||||
async def async_log_stream_event(self, kwargs, response_obj, start_time, end_time):
|
||||
print(f"On Async Streaming")
|
||||
|
||||
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||
print(f"On Async Success")
|
||||
|
||||
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
|
||||
print(f"On Async Success")
|
||||
|
||||
customHandler = MyCustomHandler()
|
||||
|
||||
litellm.callbacks = [customHandler]
|
||||
|
||||
## sync
|
||||
response = completion(model="gpt-3.5-turbo", messages=[{ "role": "user", "content": "Hi 👋 - i'm openai"}],
|
||||
stream=True)
|
||||
for chunk in response:
|
||||
continue
|
||||
|
||||
|
||||
## async
|
||||
import asyncio
|
||||
|
||||
async def completion():
|
||||
response = await acompletion(model="gpt-3.5-turbo", messages=[{ "role": "user", "content": "Hi 👋 - i'm openai"}],
|
||||
stream=True)
|
||||
async for chunk in response:
|
||||
continue
|
||||
asyncio.run(completion())
|
||||
```
|
||||
|
||||
## Callback Functions
|
||||
|
@ -87,6 +113,41 @@ print(response)
|
|||
|
||||
## Async Callback Functions
|
||||
|
||||
We recommend using the Custom Logger class for async.
|
||||
|
||||
```python
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from litellm import acompletion
|
||||
|
||||
class MyCustomHandler(CustomLogger):
|
||||
#### ASYNC ####
|
||||
|
||||
async def async_log_stream_event(self, kwargs, response_obj, start_time, end_time):
|
||||
print(f"On Async Streaming")
|
||||
|
||||
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||
print(f"On Async Success")
|
||||
|
||||
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
|
||||
print(f"On Async Failure")
|
||||
|
||||
import asyncio
|
||||
customHandler = MyCustomHandler()
|
||||
|
||||
litellm.callbacks = [customHandler]
|
||||
|
||||
async def completion():
|
||||
response = await acompletion(model="gpt-3.5-turbo", messages=[{ "role": "user", "content": "Hi 👋 - i'm openai"}],
|
||||
stream=True)
|
||||
async for chunk in response:
|
||||
continue
|
||||
asyncio.run(completion())
|
||||
```
|
||||
|
||||
**Functions**
|
||||
|
||||
If you just want to pass in an async function for logging.
|
||||
|
||||
LiteLLM currently supports just async success callback functions for async completion/embedding calls.
|
||||
|
||||
```python
|
||||
|
@ -117,9 +178,6 @@ asyncio.run(test_chat_openai())
|
|||
:::info
|
||||
|
||||
We're actively trying to expand this to other event types. [Tell us if you need this!](https://github.com/BerriAI/litellm/issues/1007)
|
||||
|
||||
|
||||
|
||||
:::
|
||||
|
||||
## What's in kwargs?
|
||||
|
@ -170,6 +228,48 @@ Here's exactly what you can expect in the kwargs dictionary:
|
|||
"end_time" = end_time # datetime object of when call was completed
|
||||
```
|
||||
|
||||
|
||||
### Cache hits
|
||||
|
||||
Cache hits are logged in success events as `kwarg["cache_hit"]`.
|
||||
|
||||
Here's an example of accessing it:
|
||||
|
||||
```python
|
||||
import litellm
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from litellm import completion, acompletion, Cache
|
||||
|
||||
class MyCustomHandler(CustomLogger):
|
||||
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||
print(f"On Success")
|
||||
print(f"Value of Cache hit: {kwargs['cache_hit']"})
|
||||
|
||||
async def test_async_completion_azure_caching():
|
||||
customHandler_caching = MyCustomHandler()
|
||||
litellm.cache = Cache(type="redis", host=os.environ['REDIS_HOST'], port=os.environ['REDIS_PORT'], password=os.environ['REDIS_PASSWORD'])
|
||||
litellm.callbacks = [customHandler_caching]
|
||||
unique_time = time.time()
|
||||
response1 = await litellm.acompletion(model="azure/chatgpt-v-2",
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": f"Hi 👋 - i'm async azure {unique_time}"
|
||||
}],
|
||||
caching=True)
|
||||
await asyncio.sleep(1)
|
||||
print(f"customHandler_caching.states pre-cache hit: {customHandler_caching.states}")
|
||||
response2 = await litellm.acompletion(model="azure/chatgpt-v-2",
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": f"Hi 👋 - i'm async azure {unique_time}"
|
||||
}],
|
||||
caching=True)
|
||||
await asyncio.sleep(1) # success callbacks are done in parallel
|
||||
print(f"customHandler_caching.states post-cache hit: {customHandler_caching.states}")
|
||||
assert len(customHandler_caching.errors) == 0
|
||||
assert len(customHandler_caching.states) == 4 # pre, post, success, success
|
||||
```
|
||||
|
||||
### Get complete streaming response
|
||||
|
||||
LiteLLM will pass you the complete streaming response in the final streaming chunk as part of the kwargs for your custom callback function.
|
||||
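A hedged sketch of reading it inside an async success callback; `complete_streaming_response` is the kwargs key litellm documents for this, but it is only set once the stream has finished, so guard for its absence:

```python
from litellm.integrations.custom_logger import CustomLogger

class StreamingResponseLogger(CustomLogger):
    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        # only present after the final streaming chunk
        full_response = kwargs.get("complete_streaming_response")
        if full_response is not None:
            print(f"Complete streaming response: {full_response}")
```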

@@ -27,8 +27,8 @@ To get better visualizations on how your code behaves, you may want to annotate

## Exporting traces to other systems (e.g. Datadog, New Relic, and others)

Since Traceloop SDK uses OpenTelemetry to send data, you can easily export your traces to other systems, such as Datadog, New Relic, and others. See [Traceloop docs on exporters](https://traceloop.com/docs/python-sdk/exporters) for more information.
Since OpenLLMetry uses OpenTelemetry to send data, you can easily export your traces to other systems, such as Datadog, New Relic, and others. See [OpenLLMetry docs on exporters](https://www.traceloop.com/docs/openllmetry/integrations/introduction) for more information.

## Support

For any question or issue with integration you can reach out to the Traceloop team on [Slack](https://join.slack.com/t/traceloopcommunity/shared_invite/zt-1plpfpm6r-zOHKI028VkpcWdobX65C~g) or via [email](mailto:dev@traceloop.com).
For any question or issue with integration you can reach out to the Traceloop team on [Slack](https://traceloop.com/slack) or via [email](mailto:dev@traceloop.com).
docs/my-website/docs/projects/Docq.AI.md (new file, 21 changes)

@@ -0,0 +1,21 @@
**A private and secure ChatGPT alternative that knows your business.**

Upload docs, ask questions --> get answers.

Leverage GenAI with your confidential documents to increase efficiency and collaboration.

OSS core; everything can run in your environment. An extensible platform you can build your GenAI strategy on. Supports a variety of popular LLMs, including embedded models for air-gapped use cases.

[![Static Badge][docs-shield]][docs-url]
[![Static Badge][github-shield]][github-url]
[![X (formerly Twitter) Follow][twitter-shield]][twitter-url]

<!-- MARKDOWN LINKS & IMAGES -->
<!-- https://www.markdownguide.org/basic-syntax/#reference-style-links -->

[docs-shield]: https://img.shields.io/badge/docs-site-black?logo=materialformkdocs
[docs-url]: https://docqai.github.io/docq/
[github-shield]: https://img.shields.io/badge/Github-repo-black?logo=github
[github-url]: https://github.com/docqai/docq/
[twitter-shield]: https://img.shields.io/twitter/follow/docqai?logo=x&style=flat
[twitter-url]: https://twitter.com/docqai
docs/my-website/docs/providers/mistral.md (new file, 56 changes)
|
@ -0,0 +1,56 @@
|
|||
# Mistral AI API
|
||||
https://docs.mistral.ai/api/
|
||||
|
||||
## API Key
|
||||
```python
|
||||
# env variable
|
||||
os.environ['MISTRAL_API_KEY']
|
||||
```
|
||||
|
||||
## Sample Usage
|
||||
```python
|
||||
from litellm import completion
|
||||
import os
|
||||
|
||||
os.environ['MISTRAL_API_KEY'] = ""
|
||||
response = completion(
|
||||
model="mistral/mistral-tiny"",
|
||||
messages=[
|
||||
{"role": "user", "content": "hello from litellm"}
|
||||
],
|
||||
)
|
||||
print(response)
|
||||
```
|
||||
|
||||
## Sample Usage - Streaming
|
||||
```python
|
||||
from litellm import completion
|
||||
import os
|
||||
|
||||
os.environ['MISTRAL_API_KEY'] = ""
|
||||
response = completion(
|
||||
model="mistral/mistral-tiny",
|
||||
messages=[
|
||||
{"role": "user", "content": "hello from litellm"}
|
||||
],
|
||||
stream=True
|
||||
)
|
||||
|
||||
for chunk in response:
|
||||
print(chunk)
|
||||
```
|
||||
|
||||
|
||||
## Supported Models
|
||||
All models listed here https://docs.mistral.ai/platform/endpoints are supported. We actively maintain the list of models, pricing, token window, etc. [here](https://github.com/BerriAI/litellm/blob/c1b25538277206b9f00de5254d80d6a83bb19a29/model_prices_and_context_window.json).
|
||||
|
||||
| Model Name | Function Call |
|
||||
|--------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
||||
| mistral-tiny | `completion(model="mistral/mistral-tiny", messages)` |
|
||||
| mistral-small | `completion(model="mistral/mistral-small", messages)` |
|
||||
| mistral-medium | `completion(model="mistral/mistral-medium", messages)` |
|
||||
|
||||
|
||||
|
||||
|
||||
|
docs/my-website/docs/providers/openai_compatible.md (new file, 47 changes)
|
@ -0,0 +1,47 @@
|
|||
# OpenAI-Compatible Endpoints
|
||||
|
||||
To call models hosted behind an openai proxy, make 2 changes:
|
||||
|
||||
1. Put `openai/` in front of your model name, so litellm knows you're trying to call an openai-compatible endpoint.
|
||||
|
||||
2. **Do NOT** add anything additional to the base url e.g. `/v1/embedding`. LiteLLM uses the openai-client to make these calls, and that automatically adds the relevant endpoints.
|
||||
|
||||
## Usage
|
||||
|
||||
```python
|
||||
import litellm
|
||||
from litellm import embedding
|
||||
litellm.set_verbose = True
|
||||
import os
|
||||
|
||||
|
||||
litellm_proxy_endpoint = "http://0.0.0.0:8000"
|
||||
bearer_token = "sk-1234"
|
||||
|
||||
CHOSEN_LITE_LLM_EMBEDDING_MODEL = "openai/GPT-J 6B - Sagemaker Text Embedding (Internal)"
|
||||
|
||||
litellm.set_verbose = False
|
||||
|
||||
print(litellm_proxy_endpoint)
|
||||
|
||||
|
||||
|
||||
response = embedding(
|
||||
|
||||
model = CHOSEN_LITE_LLM_EMBEDDING_MODEL, # add `openai/` prefix to model so litellm knows to route to OpenAI
|
||||
|
||||
api_key=bearer_token,
|
||||
|
||||
api_base=litellm_proxy_endpoint, # set API Base of your Custom OpenAI Endpoint
|
||||
|
||||
input=["good morning from litellm"],
|
||||
|
||||
api_version='2023-07-01-preview'
|
||||
|
||||
)
|
||||
|
||||
print('================================================')
|
||||
|
||||
print(len(response.data[0]['embedding']))
|
||||
|
||||
```
|
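The same two rules apply to chat models behind an OpenAI-compatible endpoint. A minimal sketch; the endpoint URL, key, and model name are placeholders:

```python
from litellm import completion

response = completion(
    model="openai/your-model-name",   # `openai/` prefix tells litellm to use the OpenAI client
    api_base="http://0.0.0.0:8000",   # base URL only; the OpenAI client adds the route itself
    api_key="sk-1234",                # placeholder bearer token
    messages=[{"role": "user", "content": "good morning from litellm"}],
)
print(response)
```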
|
@ -1,4 +1,4 @@
|
|||
# VertexAI - Google
|
||||
# VertexAI - Google [Gemini]
|
||||
|
||||
<a target="_blank" href="https://colab.research.google.com/github/BerriAI/litellm/blob/main/cookbook/liteLLM_VertextAI_Example.ipynb">
|
||||
<img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
|
||||
|
@ -10,6 +10,16 @@
|
|||
* run `gcloud auth application-default login` See [Google Cloud Docs](https://cloud.google.com/docs/authentication/external/set-up-adc)
|
||||
* Alternatively you can set `application_default_credentials.json`
|
||||
|
||||
|
||||
## Sample Usage
|
||||
```python
|
||||
import litellm
from litellm import completion
|
||||
litellm.vertex_project = "hardy-device-38811" # Your Project ID
|
||||
litellm.vertex_location = "us-central1" # proj location
|
||||
|
||||
response = completion(model="gemini-pro", messages=[{"role": "user", "content": "write code for saying hi from LiteLLM"}])
|
||||
```
|
||||
|
||||
## Set Vertex Project & Vertex Location
|
||||
All calls using Vertex AI require the following parameters:
|
||||
* Your Project ID
|
||||
|
@ -37,13 +47,50 @@ os.environ["VERTEXAI_LOCATION"] = "us-central1 # Your Location
|
|||
litellm.vertex_location = "us-central1 # Your Location
|
||||
```
|
||||
|
||||
## Sample Usage
|
||||
## Gemini Pro
|
||||
| Model Name | Function Call |
|
||||
|------------------|--------------------------------------|
|
||||
| gemini-pro | `completion('gemini-pro', messages)` |
|
||||
|
||||
## Gemini Pro Vision
|
||||
| Model Name | Function Call |
|
||||
|------------------|--------------------------------------|
|
||||
| gemini-pro-vision | `completion('gemini-pro-vision', messages)` |
|
||||
|
||||
#### Using Gemini Pro Vision
|
||||
|
||||
Call `gemini-pro-vision` in the same input/output format as OpenAI [`gpt-4-vision`](https://docs.litellm.ai/docs/providers/openai#openai-vision-models)
|
||||
|
||||
LiteLLM Supports the following image types passed in `url`
|
||||
- Images with Cloud Storage URIs - gs://cloud-samples-data/generative-ai/image/boats.jpeg
|
||||
- Images with direct links - https://storage.googleapis.com/github-repo/img/gemini/intro/landmark3.jpg
|
||||
- Videos with Cloud Storage URIs - https://storage.googleapis.com/github-repo/img/gemini/multimodality_usecases_overview/pixel8.mp4
|
||||
|
||||
**Example Request**
|
||||
```python
|
||||
import litellm
|
||||
litellm.vertex_project = "hardy-device-38811" # Your Project ID
|
||||
litellm.vertex_location = "us-central1" # proj location
|
||||
|
||||
response = completion(model="chat-bison", messages=[{"role": "user", "content": "write code for saying hi from LiteLLM"}])
|
||||
response = litellm.completion(
|
||||
model = "vertex_ai/gemini-pro-vision",
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "Whats in this image?"
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
)
|
||||
print(response)
|
||||
```
|
||||
|
||||
## Chat Models
|
||||
|
|
|
@ -1,20 +1,24 @@
|
|||
# Caching
|
||||
Cache LLM Responses
|
||||
|
||||
## Quick Start
|
||||
Caching can be enabled by adding the `cache` key in the `config.yaml`
|
||||
#### Step 1: Add `cache` to the config.yaml
|
||||
### Step 1: Add `cache` to the config.yaml
|
||||
```yaml
|
||||
model_list:
|
||||
- model_name: gpt-3.5-turbo
|
||||
litellm_params:
|
||||
model: gpt-3.5-turbo
|
||||
- model_name: text-embedding-ada-002
|
||||
litellm_params:
|
||||
model: text-embedding-ada-002
|
||||
|
||||
litellm_settings:
|
||||
set_verbose: True
|
||||
cache: True # set cache responses to True, litellm defaults to using a redis cache
|
||||
```
|
||||
|
||||
#### Step 2: Add Redis Credentials to .env
|
||||
### Step 2: Add Redis Credentials to .env
|
||||
Set either `REDIS_URL` or the `REDIS_HOST` in your os environment, to enable caching.
|
||||
|
||||
```shell
|
||||
|
@ -32,12 +36,12 @@ REDIS_<redis-kwarg-name> = ""
|
|||
```
|
||||
|
||||
[**See how it's read from the environment**](https://github.com/BerriAI/litellm/blob/4d7ff1b33b9991dcf38d821266290631d9bcd2dd/litellm/_redis.py#L40)
|
||||
#### Step 3: Run proxy with config
|
||||
### Step 3: Run proxy with config
|
||||
```shell
|
||||
$ litellm --config /path/to/config.yaml
|
||||
```
|
||||
|
||||
#### Using Caching
|
||||
## Using Caching - /chat/completions
|
||||
Send the same request twice:
|
||||
```shell
|
||||
curl http://0.0.0.0:8000/v1/chat/completions \
|
||||
|
@ -57,9 +61,51 @@ curl http://0.0.0.0:8000/v1/chat/completions \
|
|||
}'
|
||||
```
|
||||
|
||||
#### Control caching per completion request
|
||||
## Using Caching - /embeddings
|
||||
Send the same request twice:
|
||||
```shell
|
||||
curl --location 'http://0.0.0.0:8000/embeddings' \
|
||||
--header 'Content-Type: application/json' \
|
||||
--data ' {
|
||||
"model": "text-embedding-ada-002",
|
||||
"input": ["write a litellm poem"]
|
||||
}'
|
||||
|
||||
curl --location 'http://0.0.0.0:8000/embeddings' \
|
||||
--header 'Content-Type: application/json' \
|
||||
--data ' {
|
||||
"model": "text-embedding-ada-002",
|
||||
"input": ["write a litellm poem"]
|
||||
}'
|
||||
```
|
||||
|
||||
## Advanced
|
||||
### Set Cache Params on config.yaml
|
||||
```yaml
|
||||
model_list:
|
||||
- model_name: gpt-3.5-turbo
|
||||
litellm_params:
|
||||
model: gpt-3.5-turbo
|
||||
- model_name: text-embedding-ada-002
|
||||
litellm_params:
|
||||
model: text-embedding-ada-002
|
||||
|
||||
litellm_settings:
|
||||
set_verbose: True
|
||||
cache: True # set cache responses to True, litellm defaults to using a redis cache
|
||||
cache_params: # cache_params are optional
|
||||
type: "redis" # The type of cache to initialize. Can be "local" or "redis". Defaults to "local".
|
||||
host: "localhost" # The host address for the Redis cache. Required if type is "redis".
|
||||
port: 6379 # The port number for the Redis cache. Required if type is "redis".
|
||||
password: "your_password" # The password for the Redis cache. Required if type is "redis".
|
||||
|
||||
# Optional configurations
|
||||
supported_call_types: ["acompletion", "completion", "embedding", "aembedding"] # defaults to all litellm call types
|
||||
```
|
||||
|
||||
### Override caching per `chat/completions` request
|
||||
Caching can be switched on/off per `/chat/completions` request
|
||||
- Caching **on** for completion - pass `caching=True`:
|
||||
- Caching **on** for individual completion - pass `caching=True`:
|
||||
```shell
|
||||
curl http://0.0.0.0:8000/v1/chat/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
|
@ -70,7 +116,7 @@ Caching can be switched on/off per `/chat/completions` request
|
|||
"caching": true
|
||||
}'
|
||||
```
|
||||
- Caching **off** for completion - pass `caching=False`:
|
||||
- Caching **off** for individual completion - pass `caching=False`:
|
||||
```shell
|
||||
curl http://0.0.0.0:8000/v1/chat/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
|
@ -80,4 +126,29 @@ Caching can be switched on/off per `/chat/completions` request
|
|||
"temperature": 0.7,
|
||||
"caching": false
|
||||
}'
|
||||
```
|
||||
|
||||
|
||||
### Override caching per `/embeddings` request
|
||||
|
||||
Caching can be switched on/off per `/embeddings` request
|
||||
- Caching **on** for embedding - pass `caching=True`:
|
||||
```shell
|
||||
curl --location 'http://0.0.0.0:8000/embeddings' \
|
||||
--header 'Content-Type: application/json' \
|
||||
--data ' {
|
||||
"model": "text-embedding-ada-002",
|
||||
"input": ["write a litellm poem"],
|
||||
"caching": true
|
||||
}'
|
||||
```
|
||||
- Caching **off** for embedding - pass `caching=False`:
|
||||
```shell
|
||||
curl --location 'http://0.0.0.0:8000/embeddings' \
|
||||
--header 'Content-Type: application/json' \
|
||||
--data ' {
|
||||
"model": "text-embedding-ada-002",
|
||||
"input": ["write a litellm poem"],
|
||||
"caching": false
|
||||
}'
|
||||
```
|
docs/my-website/docs/proxy/call_hooks.md (new file, 78 changes)
|
@ -0,0 +1,78 @@
|
|||
# Modify Incoming Data
|
||||
|
||||
Modify data just before the proxy makes the litellm completion call
|
||||
|
||||
See a complete example with our [parallel request rate limiter](https://github.com/BerriAI/litellm/blob/main/litellm/proxy/hooks/parallel_request_limiter.py)
|
||||
|
||||
## Quick Start
|
||||
|
||||
1. In your Custom Handler add a new `async_pre_call_hook` function
|
||||
|
||||
This function is called just before a litellm completion call is made, and allows you to modify the data going into the litellm call [**See Code**](https://github.com/BerriAI/litellm/blob/589a6ca863000ba8e92c897ba0f776796e7a5904/litellm/proxy/proxy_server.py#L1000)
|
||||
|
||||
```python
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
import litellm
# imports for the type hints used in async_pre_call_hook below
from litellm.proxy._types import UserAPIKeyAuth
from litellm.caching import DualCache
from typing import Literal
|
||||
|
||||
# This file includes the custom callbacks for LiteLLM Proxy
|
||||
# Once defined, these can be passed in proxy_config.yaml
|
||||
class MyCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/observability/custom_callback#callback-class
|
||||
# Class variables or attributes
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
#### ASYNC ####
|
||||
|
||||
async def async_log_stream_event(self, kwargs, response_obj, start_time, end_time):
|
||||
pass
|
||||
|
||||
async def async_log_pre_api_call(self, model, messages, kwargs):
|
||||
pass
|
||||
|
||||
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||
pass
|
||||
|
||||
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
|
||||
pass
|
||||
|
||||
#### CALL HOOKS - proxy only ####
|
||||
|
||||
async def async_pre_call_hook(self, user_api_key_dict: UserAPIKeyAuth, cache: DualCache, data: dict, call_type: Literal["completion", "embeddings"]):
|
||||
data["model"] = "my-new-model"
|
||||
return data
|
||||
|
||||
proxy_handler_instance = MyCustomHandler()
|
||||
```
|
||||
|
||||
2. Add this file to your proxy config
|
||||
|
||||
```yaml
|
||||
model_list:
|
||||
- model_name: gpt-3.5-turbo
|
||||
litellm_params:
|
||||
model: gpt-3.5-turbo
|
||||
|
||||
litellm_settings:
|
||||
callbacks: custom_callbacks.proxy_handler_instance # sets litellm.callbacks = [proxy_handler_instance]
|
||||
```
|
||||
|
||||
3. Start the server + test the request
|
||||
|
||||
```shell
|
||||
$ litellm --config /path/to/config.yaml
|
||||
```
|
||||
```shell
|
||||
curl --location 'http://0.0.0.0:8000/chat/completions' \
|
||||
--data ' {
|
||||
"model": "gpt-3.5-turbo",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "good morning good sir"
|
||||
}
|
||||
],
|
||||
"user": "ishaan-app",
|
||||
"temperature": 0.2
|
||||
}'
|
||||
```
|
||||
|
|
@ -1,4 +1,11 @@
|
|||
# Deploying LiteLLM Proxy
|
||||
import Tabs from '@theme/Tabs';
|
||||
import TabItem from '@theme/TabItem';
|
||||
|
||||
# 🐳 Docker, Deploying LiteLLM Proxy
|
||||
|
||||
## Dockerfile
|
||||
|
||||
You can find the Dockerfile to build litellm proxy [here](https://github.com/BerriAI/litellm/blob/main/Dockerfile)
|
||||
|
||||
## Quick Start Docker Image: Github Container Registry
|
||||
|
||||
|
@ -7,12 +14,12 @@ See the latest available ghcr docker image here:
|
|||
https://github.com/berriai/litellm/pkgs/container/litellm
|
||||
|
||||
```shell
|
||||
docker pull ghcr.io/berriai/litellm:main-v1.10.1
|
||||
docker pull ghcr.io/berriai/litellm:main-v1.12.3
|
||||
```
|
||||
|
||||
### Run the Docker Image
|
||||
```shell
|
||||
docker run ghcr.io/berriai/litellm:main-v1.10.0
|
||||
docker run ghcr.io/berriai/litellm:main-v1.12.3
|
||||
```
|
||||
|
||||
#### Run the Docker Image with LiteLLM CLI args
|
||||
|
@ -21,12 +28,12 @@ See all supported CLI args [here](https://docs.litellm.ai/docs/proxy/cli):
|
|||
|
||||
Here's how you can run the docker image and pass your config to `litellm`
|
||||
```shell
|
||||
docker run ghcr.io/berriai/litellm:main-v1.10.0 --config your_config.yaml
|
||||
docker run ghcr.io/berriai/litellm:main-v1.12.3 --config your_config.yaml
|
||||
```
|
||||
|
||||
Here's how you can run the docker image and start litellm on port 8002 with `num_workers=8`
|
||||
```shell
|
||||
docker run ghcr.io/berriai/litellm:main-v1.10.0 --port 8002 --num_workers 8
|
||||
docker run ghcr.io/berriai/litellm:main-v1.12.3 --port 8002 --num_workers 8
|
||||
```
|
||||
|
||||
#### Run the Docker Image using docker compose
|
||||
|
@ -42,6 +49,10 @@ Here's an example `docker-compose.yml` file
|
|||
version: "3.9"
|
||||
services:
|
||||
litellm:
|
||||
build:
|
||||
context: .
|
||||
args:
|
||||
target: runtime
|
||||
image: ghcr.io/berriai/litellm:main
|
||||
ports:
|
||||
- "8000:8000" # Map the container port to the host, change the host port if necessary
|
||||
|
@ -74,6 +85,26 @@ Your LiteLLM container should be running now on the defined port e.g. `8000`.
|
|||
<iframe width="840" height="500" src="https://www.loom.com/embed/805964b3c8384b41be180a61442389a3" frameborder="0" webkitallowfullscreen mozallowfullscreen allowfullscreen></iframe>
|
||||
|
||||
|
||||
## Deploy on Google Cloud Run
|
||||
**Click the button** to deploy to Google Cloud Run
|
||||
|
||||
[](https://deploy.cloud.run/?git_repo=https://github.com/BerriAI/litellm)
|
||||
|
||||
#### Testing your deployed proxy
|
||||
**Assuming the required keys are set as Environment Variables**
|
||||
|
||||
https://litellm-7yjrj3ha2q-uc.a.run.app is our example proxy, substitute it with your deployed cloud run app
|
||||
|
||||
```shell
|
||||
curl https://litellm-7yjrj3ha2q-uc.a.run.app/v1/chat/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "gpt-3.5-turbo",
|
||||
"messages": [{"role": "user", "content": "Say this is a test!"}],
|
||||
"temperature": 0.7
|
||||
}'
|
||||
```
|
||||
|
||||
## LiteLLM Proxy Performance
|
||||
|
||||
LiteLLM proxy has been load tested to handle 1500 req/s.
|
||||
|
|
docs/my-website/docs/proxy/embedding.md (new file, 244 changes)
|
@ -0,0 +1,244 @@
|
|||
import Image from '@theme/IdealImage';
|
||||
import Tabs from '@theme/Tabs';
|
||||
import TabItem from '@theme/TabItem';
|
||||
|
||||
# Embeddings - `/embeddings`
|
||||
|
||||
See supported Embedding Providers & Models [here](https://docs.litellm.ai/docs/embedding/supported_embedding)
|
||||
|
||||
|
||||
## Quick start
|
||||
Here's how to route between GPT-J embedding (sagemaker endpoint), Amazon Titan embedding (Bedrock) and Azure OpenAI embedding on the proxy server:
|
||||
|
||||
1. Set models in your config.yaml
|
||||
```yaml
|
||||
model_list:
|
||||
- model_name: sagemaker-embeddings
|
||||
litellm_params:
|
||||
model: "sagemaker/berri-benchmarking-gpt-j-6b-fp16"
|
||||
- model_name: amazon-embeddings
|
||||
litellm_params:
|
||||
model: "bedrock/amazon.titan-embed-text-v1"
|
||||
- model_name: azure-embeddings
|
||||
litellm_params:
|
||||
model: "azure/azure-embedding-model"
|
||||
api_base: "os.environ/AZURE_API_BASE" # os.getenv("AZURE_API_BASE")
|
||||
api_key: "os.environ/AZURE_API_KEY" # os.getenv("AZURE_API_KEY")
|
||||
api_version: "2023-07-01-preview"
|
||||
|
||||
general_settings:
|
||||
master_key: sk-1234 # [OPTIONAL] if set all calls to proxy will require either this key or a valid generated token
|
||||
```
|
||||
|
||||
2. Start the proxy
|
||||
```shell
|
||||
$ litellm --config /path/to/config.yaml
|
||||
```
|
||||
|
||||
3. Test the embedding call
|
||||
|
||||
```shell
|
||||
curl --location 'http://0.0.0.0:8000/v1/embeddings' \
|
||||
--header 'Authorization: Bearer sk-1234' \
|
||||
--header 'Content-Type: application/json' \
|
||||
--data '{
|
||||
"input": "The food was delicious and the waiter..",
|
||||
"model": "sagemaker-embeddings",
|
||||
}'
|
||||
```
|
||||
|
||||
## `/embeddings` Request Format
|
||||
Input, Output and Exceptions are mapped to the OpenAI format for all supported models
|
||||
|
||||
<Tabs>
|
||||
<TabItem value="Curl" label="Curl Request">
|
||||
|
||||
```shell
|
||||
curl --location 'http://0.0.0.0:8000/embeddings' \
|
||||
--header 'Content-Type: application/json' \
|
||||
--data ' {
|
||||
"model": "text-embedding-ada-002",
|
||||
"input": ["write a litellm poem"]
|
||||
}'
|
||||
```
|
||||
</TabItem>
|
||||
<TabItem value="openai" label="OpenAI v1.0.0+">
|
||||
|
||||
```python
|
||||
import openai
|
||||
from openai import OpenAI
|
||||
|
||||
# set base_url to your proxy server
|
||||
# set api_key to send to proxy server
|
||||
client = OpenAI(api_key="<proxy-api-key>", base_url="http://0.0.0.0:8000")
|
||||
|
||||
response = client.embeddings.create(
|
||||
input=["hello from litellm"],
|
||||
model="text-embedding-ada-002"
|
||||
)
|
||||
|
||||
print(response)
|
||||
|
||||
```
|
||||
</TabItem>
|
||||
|
||||
<TabItem value="langchain-embedding" label="Langchain Embeddings">
|
||||
|
||||
```python
|
||||
from langchain.embeddings import OpenAIEmbeddings
|
||||
|
||||
embeddings = OpenAIEmbeddings(model="sagemaker-embeddings", openai_api_base="http://0.0.0.0:8000", openai_api_key="temp-key")
|
||||
|
||||
|
||||
text = "This is a test document."
|
||||
|
||||
query_result = embeddings.embed_query(text)
|
||||
|
||||
print(f"SAGEMAKER EMBEDDINGS")
|
||||
print(query_result[:5])
|
||||
|
||||
embeddings = OpenAIEmbeddings(model="bedrock-embeddings", openai_api_base="http://0.0.0.0:8000", openai_api_key="temp-key")
|
||||
|
||||
text = "This is a test document."
|
||||
|
||||
query_result = embeddings.embed_query(text)
|
||||
|
||||
print(f"BEDROCK EMBEDDINGS")
|
||||
print(query_result[:5])
|
||||
|
||||
embeddings = OpenAIEmbeddings(model="bedrock-titan-embeddings", openai_api_base="http://0.0.0.0:8000", openai_api_key="temp-key")
|
||||
|
||||
text = "This is a test document."
|
||||
|
||||
query_result = embeddings.embed_query(text)
|
||||
|
||||
print(f"TITAN EMBEDDINGS")
|
||||
print(query_result[:5])
|
||||
```
|
||||
</TabItem>
|
||||
</Tabs>
|
||||
|
||||
|
||||
|
||||
## `/embeddings` Response Format
|
||||
|
||||
```json
|
||||
{
|
||||
"object": "list",
|
||||
"data": [
|
||||
{
|
||||
"object": "embedding",
|
||||
"embedding": [
|
||||
0.0023064255,
|
||||
-0.009327292,
|
||||
....
|
||||
-0.0028842222,
|
||||
],
|
||||
"index": 0
|
||||
}
|
||||
],
|
||||
"model": "text-embedding-ada-002",
|
||||
"usage": {
|
||||
"prompt_tokens": 8,
|
||||
"total_tokens": 8
|
||||
}
|
||||
}
|
||||
|
||||
```
|
||||
|
||||
## Supported Models
|
||||
|
||||
See supported Embedding Providers & Models [here](https://docs.litellm.ai/docs/embedding/supported_embedding)
|
||||
|
||||
#### Create Config.yaml
|
||||
|
||||
<Tabs>
|
||||
<TabItem value="Hugging Face emb" label="Hugging Face Embeddings">
|
||||
LiteLLM Proxy supports all <a href="https://huggingface.co/models?pipeline_tag=feature-extraction">Feature-Extraction Embedding models</a>.
|
||||
|
||||
```yaml
|
||||
model_list:
|
||||
- model_name: deployed-codebert-base
|
||||
litellm_params:
|
||||
# send request to deployed hugging face inference endpoint
|
||||
model: huggingface/microsoft/codebert-base # add huggingface prefix so it routes to hugging face
|
||||
api_key: hf_LdS # api key for hugging face inference endpoint
|
||||
api_base: https://uysneno1wv2wd4lw.us-east-1.aws.endpoints.huggingface.cloud # your hf inference endpoint
|
||||
- model_name: codebert-base
|
||||
litellm_params:
|
||||
# no api_base set, sends request to hugging face free inference api https://api-inference.huggingface.co/models/
|
||||
model: huggingface/microsoft/codebert-base # add huggingface prefix so it routes to hugging face
|
||||
api_key: hf_LdS # api key for hugging face
|
||||
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
|
||||
<TabItem value="azure" label="Azure OpenAI Embeddings">
|
||||
|
||||
```yaml
|
||||
model_list:
|
||||
- model_name: azure-embedding-model # model group
|
||||
litellm_params:
|
||||
model: azure/azure-embedding-model # model name for litellm.embedding(model=azure/azure-embedding-model) call
|
||||
api_base: your-azure-api-base
|
||||
api_key: your-api-key
|
||||
api_version: 2023-07-01-preview
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
|
||||
<TabItem value="openai" label="OpenAI Embeddings">
|
||||
|
||||
```yaml
|
||||
model_list:
|
||||
- model_name: text-embedding-ada-002 # model group
|
||||
litellm_params:
|
||||
model: text-embedding-ada-002 # model name for litellm.embedding(model=text-embedding-ada-002)
|
||||
api_key: your-api-key-1
|
||||
- model_name: text-embedding-ada-002
|
||||
litellm_params:
|
||||
model: text-embedding-ada-002
|
||||
api_key: your-api-key-2
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
|
||||
<TabItem value="openai emb" label="OpenAI Compatible Embeddings">
|
||||
|
||||
<p>Use this for calling <a href="https://github.com/xorbitsai/inference">/embedding endpoints on OpenAI Compatible Servers</a>.</p>
|
||||
|
||||
**Note add `openai/` prefix to `litellm_params`: `model` so litellm knows to route to OpenAI**
|
||||
|
||||
```yaml
|
||||
model_list:
|
||||
- model_name: text-embedding-ada-002 # model group
|
||||
litellm_params:
|
||||
model: openai/<your-model-name> # model name for litellm.embedding(model=text-embedding-ada-002)
|
||||
api_base: <model-api-base>
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
</Tabs>
|
||||
|
||||
#### Start Proxy
|
||||
```shell
|
||||
litellm --config config.yaml
|
||||
```
|
||||
|
||||
#### Make Request
|
||||
Sends Request to `deployed-codebert-base`
|
||||
|
||||
```shell
|
||||
curl --location 'http://0.0.0.0:8000/embeddings' \
|
||||
--header 'Content-Type: application/json' \
|
||||
--data ' {
|
||||
"model": "deployed-codebert-base",
|
||||
"input": ["write a litellm poem"]
|
||||
}'
|
||||
```
|
||||
|
||||
|
||||
|
||||
|
||||
|
docs/my-website/docs/proxy/health.md (new file, 62 changes)
|
@ -0,0 +1,62 @@
|
|||
# Health Checks
|
||||
Use this to health check all LLMs defined in your config.yaml
|
||||
|
||||
## Summary
|
||||
|
||||
The proxy exposes:
|
||||
* a /health endpoint which returns the health of the LLM APIs
|
||||
* a /test endpoint which makes a ping to the litellm server
|
||||
|
||||
#### Request
|
||||
Make a GET Request to `/health` on the proxy
|
||||
```shell
|
||||
curl --location 'http://0.0.0.0:8000/health'
|
||||
```
|
||||
|
||||
You can also run `litellm --health`; it makes a `GET` request to `http://0.0.0.0:8000/health` for you.
|
||||
```
|
||||
litellm --health
|
||||
```
|
||||
#### Response
|
||||
```shell
|
||||
{
|
||||
"healthy_endpoints": [
|
||||
{
|
||||
"model": "azure/gpt-35-turbo",
|
||||
"api_base": "https://my-endpoint-canada-berri992.openai.azure.com/"
|
||||
},
|
||||
{
|
||||
"model": "azure/gpt-35-turbo",
|
||||
"api_base": "https://my-endpoint-europe-berri-992.openai.azure.com/"
|
||||
}
|
||||
],
|
||||
"unhealthy_endpoints": [
|
||||
{
|
||||
"model": "azure/gpt-35-turbo",
|
||||
"api_base": "https://openai-france-1234.openai.azure.com/"
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
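A small client-side sketch for polling this endpoint; it assumes the proxy is reachable at `http://0.0.0.0:8000` and returns the shape shown above:

```python
import requests

resp = requests.get("http://0.0.0.0:8000/health", timeout=10)
resp.raise_for_status()
report = resp.json()

print(f"healthy endpoints: {len(report.get('healthy_endpoints', []))}")
print(f"unhealthy endpoints: {len(report.get('unhealthy_endpoints', []))}")
for endpoint in report.get("unhealthy_endpoints", []):
    print(f"unhealthy: {endpoint['model']} @ {endpoint['api_base']}")
```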
|
||||
## Background Health Checks
|
||||
|
||||
You can enable model health checks being run in the background, to prevent each model from being queried too frequently via `/health`.
|
||||
|
||||
Here's how to use it:
|
||||
1. in the config.yaml add:
|
||||
```
|
||||
general_settings:
|
||||
background_health_checks: True # enable background health checks
|
||||
health_check_interval: 300 # frequency of background health checks
|
||||
```
|
||||
|
||||
2. Start server
|
||||
```
|
||||
$ litellm --config /path/to/config.yaml
|
||||
```
|
||||
|
||||
3. Query health endpoint:
|
||||
```
|
||||
curl --location 'http://0.0.0.0:8000/health'
|
||||
```
|
|
@ -72,128 +72,28 @@ curl --location 'http://0.0.0.0:8000/chat/completions' \
|
|||
'
|
||||
```
|
||||
|
||||
## Router settings on config - routing_strategy, model_group_alias
|
||||
|
||||
## Fallbacks + Cooldowns + Retries + Timeouts
|
||||
litellm.Router() settings can be set under `router_settings`. You can set `model_group_alias`, `routing_strategy`, `num_retries`, `timeout`. See all Router-supported params [here](https://github.com/BerriAI/litellm/blob/1b942568897a48f014fa44618ec3ce54d7570a46/litellm/router.py#L64)
|
||||
|
||||
If a call fails after num_retries, fall back to another model group.
|
||||
|
||||
If the error is a context window exceeded error, fall back to a larger model group (if given).
|
||||
|
||||
[**See Code**](https://github.com/BerriAI/litellm/blob/main/litellm/router.py)
|
||||
|
||||
**Set via config**
|
||||
```yaml
|
||||
model_list:
|
||||
- model_name: zephyr-beta
|
||||
litellm_params:
|
||||
model: huggingface/HuggingFaceH4/zephyr-7b-beta
|
||||
api_base: http://0.0.0.0:8001
|
||||
- model_name: zephyr-beta
|
||||
litellm_params:
|
||||
model: huggingface/HuggingFaceH4/zephyr-7b-beta
|
||||
api_base: http://0.0.0.0:8002
|
||||
- model_name: zephyr-beta
|
||||
litellm_params:
|
||||
model: huggingface/HuggingFaceH4/zephyr-7b-beta
|
||||
api_base: http://0.0.0.0:8003
|
||||
- model_name: gpt-3.5-turbo
|
||||
litellm_params:
|
||||
model: gpt-3.5-turbo
|
||||
api_key: <my-openai-key>
|
||||
- model_name: gpt-3.5-turbo-16k
|
||||
litellm_params:
|
||||
model: gpt-3.5-turbo-16k
|
||||
api_key: <my-openai-key>
|
||||
|
||||
litellm_settings:
|
||||
num_retries: 3 # retry call 3 times on each model_name (e.g. zephyr-beta)
|
||||
request_timeout: 10 # raise Timeout error if call takes longer than 10s. Sets litellm.request_timeout
|
||||
fallbacks: [{"zephyr-beta": ["gpt-3.5-turbo"]}] # fallback to gpt-3.5-turbo if call fails num_retries
|
||||
context_window_fallbacks: [{"zephyr-beta": ["gpt-3.5-turbo-16k"]}, {"gpt-3.5-turbo": ["gpt-3.5-turbo-16k"]}] # fallback to gpt-3.5-turbo-16k if context window error
|
||||
allowed_fails: 3 # cooldown model if it fails > 1 call in a minute.
|
||||
```
|
||||
|
||||
**Set dynamically**
|
||||
|
||||
```bash
|
||||
curl --location 'http://0.0.0.0:8000/chat/completions' \
|
||||
--header 'Content-Type: application/json' \
|
||||
--data ' {
|
||||
"model": "zephyr-beta",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "what llm are you"
|
||||
}
|
||||
],
|
||||
"fallbacks": [{"zephyr-beta": ["gpt-3.5-turbo"]}],
|
||||
"context_window_fallbacks": [{"zephyr-beta": ["gpt-3.5-turbo"]}],
|
||||
"num_retries": 2,
|
||||
"timeout": 10
|
||||
}
|
||||
'
|
||||
```
|
||||
|
||||
## Custom Timeouts, Stream Timeouts - Per Model
|
||||
For each model you can set `timeout` & `stream_timeout` under `litellm_params`
|
||||
Example config with `router_settings`
|
||||
```yaml
|
||||
model_list:
|
||||
- model_name: gpt-3.5-turbo
|
||||
litellm_params:
|
||||
model: azure/gpt-turbo-small-eu
|
||||
api_base: https://my-endpoint-europe-berri-992.openai.azure.com/
|
||||
api_key: <your-key>
|
||||
timeout: 0.1 # timeout in (seconds)
|
||||
stream_timeout: 0.01 # timeout for stream requests (seconds)
|
||||
max_retries: 5
|
||||
model: azure/<your-deployment-name>
|
||||
api_base: <your-azure-endpoint>
|
||||
api_key: <your-azure-api-key>
|
||||
rpm: 6 # Rate limit for this deployment: in requests per minute (rpm)
|
||||
- model_name: gpt-3.5-turbo
|
||||
litellm_params:
|
||||
model: azure/gpt-turbo-small-ca
|
||||
api_base: https://my-endpoint-canada-berri992.openai.azure.com/
|
||||
api_key:
|
||||
timeout: 0.1 # timeout in (seconds)
|
||||
stream_timeout: 0.01 # timeout for stream requests (seconds)
|
||||
max_retries: 5
|
||||
|
||||
```
|
||||
|
||||
#### Start Proxy
|
||||
```shell
|
||||
$ litellm --config /path/to/config.yaml
|
||||
```
|
||||
|
||||
|
||||
|
||||
## Health Check LLMs on Proxy
|
||||
Use this to health check all LLMs defined in your config.yaml
|
||||
#### Request
|
||||
Make a GET Request to `/health` on the proxy
|
||||
```shell
|
||||
curl --location 'http://0.0.0.0:8000/health'
|
||||
```
|
||||
|
||||
You can also run `litellm --health`; it makes a `GET` request to `http://0.0.0.0:8000/health` for you.
|
||||
```
|
||||
litellm --health
|
||||
```
|
||||
#### Response
|
||||
```shell
|
||||
{
|
||||
"healthy_endpoints": [
|
||||
{
|
||||
"model": "azure/gpt-35-turbo",
|
||||
"api_base": "https://my-endpoint-canada-berri992.openai.azure.com/"
|
||||
},
|
||||
{
|
||||
"model": "azure/gpt-35-turbo",
|
||||
"api_base": "https://my-endpoint-europe-berri-992.openai.azure.com/"
|
||||
}
|
||||
],
|
||||
"unhealthy_endpoints": [
|
||||
{
|
||||
"model": "azure/gpt-35-turbo",
|
||||
"api_base": "https://openai-france-1234.openai.azure.com/"
|
||||
}
|
||||
]
|
||||
}
|
||||
api_key: <your-azure-api-key>
|
||||
rpm: 6
|
||||
router_settings:
|
||||
model_group_alias: {"gpt-4": "gpt-3.5-turbo"} # all requests with `gpt-4` will be routed to models with `gpt-3.5-turbo`
|
||||
routing_strategy: least-busy # Literal["simple-shuffle", "least-busy", "usage-based-routing", "latency-based-routing"]
|
||||
num_retries: 2
|
||||
timeout: 30 # 30 seconds
|
||||
```
|
|
@ -1,5 +1,8 @@
|
|||
# Logging - Custom Callbacks, OpenTelemetry, Langfuse
|
||||
Log Proxy Input, Output, Exceptions using Custom Callbacks, Langfuse, OpenTelemetry
|
||||
import Image from '@theme/IdealImage';
|
||||
|
||||
# Logging - Custom Callbacks, OpenTelemetry, Langfuse, Sentry
|
||||
|
||||
Log Proxy Input, Output, Exceptions using Custom Callbacks, Langfuse, OpenTelemetry, Sentry, DynamoDB
|
||||
|
||||
## Custom Callback Class [Async]
|
||||
Use this when you want to run custom callbacks in `python`
|
||||
|
@ -486,3 +489,166 @@ litellm --test
|
|||
Expected output on Langfuse
|
||||
|
||||
<Image img={require('../../img/langfuse_small.png')} />
|
||||
|
||||
## Logging Proxy Input/Output - DynamoDB
|
||||
|
||||
We will use the `--config` to set
|
||||
- `litellm.success_callback = ["dynamodb"]`
|
||||
- `litellm.dynamodb_table_name = "your-table-name"`
|
||||
|
||||
This will log all successful LLM calls to DynamoDB.
|
||||
|
||||
**Step 1** Set AWS Credentials in .env
|
||||
|
||||
```shell
|
||||
AWS_ACCESS_KEY_ID = ""
|
||||
AWS_SECRET_ACCESS_KEY = ""
|
||||
AWS_REGION_NAME = ""
|
||||
```
|
||||
|
||||
**Step 2**: Create a `config.yaml` file and set `litellm_settings`: `success_callback`
|
||||
```yaml
|
||||
model_list:
|
||||
- model_name: gpt-3.5-turbo
|
||||
litellm_params:
|
||||
model: gpt-3.5-turbo
|
||||
litellm_settings:
|
||||
success_callback: ["dynamodb"]
|
||||
dynamodb_table_name: your-table-name
|
||||
```
|
||||
|
||||
**Step 3**: Start the proxy, make a test request
|
||||
|
||||
Start proxy
|
||||
```shell
|
||||
litellm --config config.yaml --debug
|
||||
```
|
||||
|
||||
Test Request
|
||||
```shell
|
||||
curl --location 'http://0.0.0.0:8000/chat/completions' \
|
||||
--header 'Content-Type: application/json' \
|
||||
--data ' {
|
||||
"model": "Azure OpenAI GPT-4 East",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "what llm are you"
|
||||
}
|
||||
]
|
||||
}'
|
||||
```
|
||||
|
||||
Your logs should be available on DynamoDB
|
||||
|
||||
#### Data Logged to DynamoDB /chat/completions
|
||||
|
||||
```json
|
||||
{
|
||||
"id": {
|
||||
"S": "chatcmpl-8W15J4480a3fAQ1yQaMgtsKJAicen"
|
||||
},
|
||||
"call_type": {
|
||||
"S": "acompletion"
|
||||
},
|
||||
"endTime": {
|
||||
"S": "2023-12-15 17:25:58.424118"
|
||||
},
|
||||
"messages": {
|
||||
"S": "[{'role': 'user', 'content': 'This is a test'}]"
|
||||
},
|
||||
"metadata": {
|
||||
"S": "{}"
|
||||
},
|
||||
"model": {
|
||||
"S": "gpt-3.5-turbo"
|
||||
},
|
||||
"modelParameters": {
|
||||
"S": "{'temperature': 0.7, 'max_tokens': 100, 'user': 'ishaan-2'}"
|
||||
},
|
||||
"response": {
|
||||
"S": "ModelResponse(id='chatcmpl-8W15J4480a3fAQ1yQaMgtsKJAicen', choices=[Choices(finish_reason='stop', index=0, message=Message(content='Great! What can I assist you with?', role='assistant'))], created=1702641357, model='gpt-3.5-turbo-0613', object='chat.completion', system_fingerprint=None, usage=Usage(completion_tokens=9, prompt_tokens=11, total_tokens=20))"
|
||||
},
|
||||
"startTime": {
|
||||
"S": "2023-12-15 17:25:56.047035"
|
||||
},
|
||||
"usage": {
|
||||
"S": "Usage(completion_tokens=9, prompt_tokens=11, total_tokens=20)"
|
||||
},
|
||||
"user": {
|
||||
"S": "ishaan-2"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
#### Data logged to DynamoDB /embeddings
|
||||
|
||||
```json
|
||||
{
|
||||
"id": {
|
||||
"S": "4dec8d4d-4817-472d-9fc6-c7a6153eb2ca"
|
||||
},
|
||||
"call_type": {
|
||||
"S": "aembedding"
|
||||
},
|
||||
"endTime": {
|
||||
"S": "2023-12-15 17:25:59.890261"
|
||||
},
|
||||
"messages": {
|
||||
"S": "['hi']"
|
||||
},
|
||||
"metadata": {
|
||||
"S": "{}"
|
||||
},
|
||||
"model": {
|
||||
"S": "text-embedding-ada-002"
|
||||
},
|
||||
"modelParameters": {
|
||||
"S": "{'user': 'ishaan-2'}"
|
||||
},
|
||||
"response": {
|
||||
"S": "EmbeddingResponse(model='text-embedding-ada-002-v2', data=[{'embedding': [-0.03503197431564331, -0.020601635798811913, -0.015375726856291294,
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
|
||||
|
||||
## Logging Proxy Input/Output - Sentry

If API calls fail (LLM / database), you can log them to Sentry:

**Step 1**: Install Sentry
```shell
pip install --upgrade sentry-sdk
```

**Step 2**: Save your `SENTRY_DSN` and add `litellm_settings`: `failure_callback`
```shell
export SENTRY_DSN="your-sentry-dsn"
```

```yaml
model_list:
  - model_name: gpt-3.5-turbo
    litellm_params:
      model: gpt-3.5-turbo
litellm_settings:
  # other settings
  failure_callback: ["sentry"]
general_settings:
  database_url: "my-bad-url" # set a fake url to trigger a sentry exception
```

**Step 3**: Start the proxy, make a test request

Start proxy
```shell
litellm --config config.yaml --debug
```

Test Request
```shell
litellm --test
```
docs/my-website/docs/proxy/reliability.md (new file, 89 lines)
@@ -0,0 +1,89 @@
# Fallbacks, Retries, Timeouts, Cooldowns

If a call fails after num_retries, fall back to another model group.

If the error is a context window exceeded error, fall back to a larger model group (if given).

[**See Code**](https://github.com/BerriAI/litellm/blob/main/litellm/router.py)

**Set via config**
```yaml
model_list:
  - model_name: zephyr-beta
    litellm_params:
      model: huggingface/HuggingFaceH4/zephyr-7b-beta
      api_base: http://0.0.0.0:8001
  - model_name: zephyr-beta
    litellm_params:
      model: huggingface/HuggingFaceH4/zephyr-7b-beta
      api_base: http://0.0.0.0:8002
  - model_name: zephyr-beta
    litellm_params:
      model: huggingface/HuggingFaceH4/zephyr-7b-beta
      api_base: http://0.0.0.0:8003
  - model_name: gpt-3.5-turbo
    litellm_params:
      model: gpt-3.5-turbo
      api_key: <my-openai-key>
  - model_name: gpt-3.5-turbo-16k
    litellm_params:
      model: gpt-3.5-turbo-16k
      api_key: <my-openai-key>

litellm_settings:
  num_retries: 3 # retry call 3 times on each model_name (e.g. zephyr-beta)
  request_timeout: 10 # raise Timeout error if call takes longer than 10s. Sets litellm.request_timeout
  fallbacks: [{"zephyr-beta": ["gpt-3.5-turbo"]}] # fall back to gpt-3.5-turbo if the call still fails after num_retries
  context_window_fallbacks: [{"zephyr-beta": ["gpt-3.5-turbo-16k"]}, {"gpt-3.5-turbo": ["gpt-3.5-turbo-16k"]}] # fall back to gpt-3.5-turbo-16k on context window errors
  allowed_fails: 3 # cooldown the deployment if it fails more than 3 calls in a minute
```

**Set dynamically**

```bash
curl --location 'http://0.0.0.0:8000/chat/completions' \
--header 'Content-Type: application/json' \
--data ' {
    "model": "zephyr-beta",
    "messages": [
        {
        "role": "user",
        "content": "what llm are you"
        }
    ],
    "fallbacks": [{"zephyr-beta": ["gpt-3.5-turbo"]}],
    "context_window_fallbacks": [{"zephyr-beta": ["gpt-3.5-turbo"]}],
    "num_retries": 2,
    "timeout": 10
}
'
```
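
The same reliability settings can also be passed to the `Router` directly when using the Python SDK instead of the proxy. This is a sketch based on the config keys above; the keyword names (`num_retries`, `timeout`, `fallbacks`, `context_window_fallbacks`) are assumed to map 1:1 from the YAML:

```python
from litellm import Router

model_list = [
    {"model_name": "zephyr-beta", "litellm_params": {"model": "huggingface/HuggingFaceH4/zephyr-7b-beta", "api_base": "http://0.0.0.0:8001"}},
    {"model_name": "gpt-3.5-turbo", "litellm_params": {"model": "gpt-3.5-turbo", "api_key": "<my-openai-key>"}},
    {"model_name": "gpt-3.5-turbo-16k", "litellm_params": {"model": "gpt-3.5-turbo-16k", "api_key": "<my-openai-key>"}},
]

router = Router(
    model_list=model_list,
    num_retries=3,                                          # retry each call up to 3 times
    timeout=10,                                             # per-request timeout, in seconds
    fallbacks=[{"zephyr-beta": ["gpt-3.5-turbo"]}],         # generic fallback after retries are exhausted
    context_window_fallbacks=[{"zephyr-beta": ["gpt-3.5-turbo-16k"]}],  # context window errors go to the 16k model
)

response = router.completion(
    model="zephyr-beta",
    messages=[{"role": "user", "content": "what llm are you"}],
)
print(response)
```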

## Custom Timeouts, Stream Timeouts - Per Model
For each model you can set `timeout` & `stream_timeout` under `litellm_params`
```yaml
model_list:
  - model_name: gpt-3.5-turbo
    litellm_params:
      model: azure/gpt-turbo-small-eu
      api_base: https://my-endpoint-europe-berri-992.openai.azure.com/
      api_key: <your-key>
      timeout: 0.1          # timeout in (seconds)
      stream_timeout: 0.01  # timeout for stream requests (seconds)
      max_retries: 5
  - model_name: gpt-3.5-turbo
    litellm_params:
      model: azure/gpt-turbo-small-ca
      api_base: https://my-endpoint-canada-berri992.openai.azure.com/
      api_key:
      timeout: 0.1          # timeout in (seconds)
      stream_timeout: 0.01  # timeout for stream requests (seconds)
      max_retries: 5
```

#### Start Proxy
```shell
$ litellm --config /path/to/config.yaml
```
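
To see the per-model timeouts in action, send a normal request to the proxy; with the (deliberately aggressive) `timeout: 0.1` above, the call should fail fast with a timeout error. A sketch, assuming the proxy is listening on the same port used elsewhere in these docs:

```shell
curl --location 'http://0.0.0.0:8000/chat/completions' \
--header 'Content-Type: application/json' \
--data '{
    "model": "gpt-3.5-turbo",
    "messages": [{"role": "user", "content": "what llm are you"}]
}'
```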
@ -1,13 +1,13 @@
|
|||
import Tabs from '@theme/Tabs';
|
||||
import TabItem from '@theme/TabItem';
|
||||
|
||||
# [OLD PROXY 👉 [**NEW** proxy here](./simple_proxy.md)] Local OpenAI Proxy Server
|
||||
# [OLD PROXY 👉 [**NEW** proxy here](./simple_proxy)] Local OpenAI Proxy Server
|
||||
|
||||
A fast, and lightweight OpenAI-compatible server to call 100+ LLM APIs.
|
||||
|
||||
:::info
|
||||
|
||||
Docs outdated. New docs 👉 [here](./simple_proxy.md)
|
||||
Docs outdated. New docs 👉 [here](./simple_proxy)
|
||||
|
||||
:::
|
||||
|
||||
|
|
|
@@ -366,6 +366,63 @@ router = Router(model_list: Optional[list] = None,
                cache_responses=True)
```

## Caching across model groups

If you want to cache across 2 different model groups (e.g. azure deployments, and openai), use caching groups.

```python
import litellm, asyncio, time, os, traceback
from litellm import Router

# set os env
os.environ["OPENAI_API_KEY"] = ""
os.environ["AZURE_API_KEY"] = ""
os.environ["AZURE_API_BASE"] = ""
os.environ["AZURE_API_VERSION"] = ""

async def test_acompletion_caching_on_router_caching_groups():
    # tests acompletion + caching on router
    try:
        litellm.set_verbose = True
        model_list = [
            {
                "model_name": "openai-gpt-3.5-turbo",
                "litellm_params": {
                    "model": "gpt-3.5-turbo-0613",
                    "api_key": os.getenv("OPENAI_API_KEY"),
                },
            },
            {
                "model_name": "azure-gpt-3.5-turbo",
                "litellm_params": {
                    "model": "azure/chatgpt-v-2",
                    "api_key": os.getenv("AZURE_API_KEY"),
                    "api_base": os.getenv("AZURE_API_BASE"),
                    "api_version": os.getenv("AZURE_API_VERSION")
                },
            }
        ]

        messages = [
            {"role": "user", "content": f"write a one sentence poem {time.time()}?"}
        ]
        start_time = time.time()
        router = Router(model_list=model_list,
                        cache_responses=True,
                        caching_groups=[("openai-gpt-3.5-turbo", "azure-gpt-3.5-turbo")])
        response1 = await router.acompletion(model="openai-gpt-3.5-turbo", messages=messages, temperature=1)
        print(f"response1: {response1}")
        await asyncio.sleep(1) # add cache is async, async sleep for cache to get set
        response2 = await router.acompletion(model="azure-gpt-3.5-turbo", messages=messages, temperature=1)
        assert response1.id == response2.id
        assert len(response1.choices[0].message.content) > 0
        assert response1.choices[0].message.content == response2.choices[0].message.content
    except Exception as e:
        traceback.print_exc()

asyncio.run(test_acompletion_caching_on_router_caching_groups())
```

#### Default litellm.completion/embedding params
|
||||
|
||||
You can also set default params for litellm completion/embedding calls. Here's how to do that:
|
||||
|
@ -391,200 +448,3 @@ print(f"response: {response}")
|
|||
## Deploy Router
|
||||
|
||||
If you want a server to load balance across different LLM APIs, use our [OpenAI Proxy Server](./simple_proxy#load-balancing---multiple-instances-of-1-model)
|
||||
|
||||
## Queuing (Beta)
|
||||
|
||||
**Never fail a request due to rate limits**
|
||||
|
||||
The LiteLLM Queuing endpoints can handle 100+ req/s. We use Celery workers to process requests.
|
||||
|
||||
:::info
|
||||
|
||||
This is pretty new, and might have bugs. Any contributions to improving our implementation are welcome
|
||||
|
||||
:::
|
||||
|
||||
|
||||
[**See Code**](https://github.com/BerriAI/litellm/blob/fbf9cab5b9e35df524e2c9953180c58d92e4cd97/litellm/proxy/proxy_server.py#L589)
|
||||
|
||||
|
||||
### Quick Start
|
||||
|
||||
1. Add Redis credentials in a .env file
|
||||
|
||||
```shell
|
||||
REDIS_HOST="my-redis-endpoint"
|
||||
REDIS_PORT="my-redis-port"
|
||||
REDIS_PASSWORD="my-redis-password" # [OPTIONAL] if self-hosted
|
||||
REDIS_USERNAME="default" # [OPTIONAL] if self-hosted
|
||||
```
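
Before starting the server, it can help to confirm these credentials actually reach your Redis instance. A small sketch using the `redis` Python package (an extra check, not part of the queue setup itself):

```python
import os
import redis  # pip install redis

r = redis.Redis(
    host=os.environ["REDIS_HOST"],
    port=int(os.environ["REDIS_PORT"]),
    password=os.environ.get("REDIS_PASSWORD"),
)
print(r.ping())  # True if the host/port/password are reachable
```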
|
||||
|
||||
2. Start litellm server with your model config
|
||||
|
||||
```bash
|
||||
$ litellm --config /path/to/config.yaml --use_queue
|
||||
```
|
||||
|
||||
Here's an example config for `gpt-3.5-turbo`
|
||||
|
||||
**config.yaml** (This will load balance between OpenAI + Azure endpoints)
|
||||
```yaml
|
||||
model_list:
|
||||
- model_name: gpt-3.5-turbo
|
||||
litellm_params:
|
||||
model: gpt-3.5-turbo
|
||||
api_key:
|
||||
- model_name: gpt-3.5-turbo
|
||||
litellm_params:
|
||||
model: azure/chatgpt-v-2 # actual model name
|
||||
api_key:
|
||||
api_version: 2023-07-01-preview
|
||||
api_base: https://openai-gpt-4-test-v-1.openai.azure.com/
|
||||
```
|
||||
|
||||
3. Test (in another window) → sends 100 simultaneous requests to the queue
|
||||
|
||||
```bash
|
||||
$ litellm --test_async --num_requests 100
|
||||
```
|
||||
|
||||
|
||||
### Available Endpoints
|
||||
- `/queue/request` - Queues a /chat/completions request. Returns a job id.
|
||||
- `/queue/response/{id}` - Returns the status of a job. If completed, returns the response as well. Potential status's are: `queued` and `finished`.
|
||||
|
||||
|
||||
## Hosted Request Queuing api.litellm.ai
|
||||
Queue your LLM API requests to ensure you're under your rate limits
|
||||
- Step 1: Add a config to the proxy, generate a temp key
|
||||
- Step 2: Queue a request to the proxy, using your generated_key
|
||||
- Step 3: Poll the request
|
||||
|
||||
|
||||
### Step 1 Add a config to the proxy, generate a temp key
|
||||
```python
|
||||
import requests
|
||||
import time
|
||||
import os
|
||||
|
||||
# Set the base URL as needed
|
||||
base_url = "https://api.litellm.ai"
|
||||
|
||||
# Step 1 Add a config to the proxy, generate a temp key
|
||||
# use the same model_name to load balance
|
||||
config = {
|
||||
"model_list": [
|
||||
{
|
||||
"model_name": "gpt-3.5-turbo",
|
||||
"litellm_params": {
|
||||
"model": "gpt-3.5-turbo",
|
||||
"api_key": os.environ['OPENAI_API_KEY'],
|
||||
}
|
||||
},
|
||||
{
|
||||
"model_name": "gpt-3.5-turbo",
|
||||
"litellm_params": {
|
||||
"model": "azure/chatgpt-v-2",
|
||||
"api_key": "",
|
||||
"api_base": "https://openai-gpt-4-test-v-1.openai.azure.com/",
|
||||
"api_version": "2023-07-01-preview"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
response = requests.post(
|
||||
url=f"{base_url}/key/generate",
|
||||
json={
|
||||
"config": config,
|
||||
"duration": "30d" # default to 30d, set it to 30m if you want a temp 30 minute key
|
||||
},
|
||||
headers={
|
||||
"Authorization": "Bearer sk-hosted-litellm" # this is the key to use api.litellm.ai
|
||||
}
|
||||
)
|
||||
|
||||
print("\nresponse from generating key", response.text)
|
||||
print("\n json response from gen key", response.json())
|
||||
|
||||
generated_key = response.json()["key"]
|
||||
print("\ngenerated key for proxy", generated_key)
|
||||
```
|
||||
|
||||
#### Output
|
||||
```shell
|
||||
response from generating key {"key":"sk-...","expires":"2023-12-22T03:43:57.615000+00:00"}
|
||||
```
|
||||
|
||||
### Step 2: Queue a request to the proxy, using your generated_key
|
||||
```python
|
||||
print("Creating a job on the proxy")
|
||||
job_response = requests.post(
|
||||
url=f"{base_url}/queue/request",
|
||||
json={
|
||||
'model': 'gpt-3.5-turbo',
|
||||
'messages': [
|
||||
{'role': 'system', 'content': f'You are a helpful assistant. What is your name'},
|
||||
],
|
||||
},
|
||||
headers={
|
||||
"Authorization": f"Bearer {generated_key}"
|
||||
}
|
||||
)
|
||||
print(job_response.status_code)
|
||||
print(job_response.text)
|
||||
print("\nResponse from creating job", job_response.text)
|
||||
job_response = job_response.json()
|
||||
job_id = job_response["id"]
|
||||
polling_url = job_response["url"]
|
||||
polling_url = f"{base_url}{polling_url}"
|
||||
print("\nCreated Job, Polling Url", polling_url)
|
||||
```
|
||||
|
||||
#### Output
|
||||
```shell
|
||||
Response from creating job
|
||||
{"id":"0e3d9e98-5d56-4d07-9cc8-c34b7e6658d7","url":"/queue/response/0e3d9e98-5d56-4d07-9cc8-c34b7e6658d7","eta":5,"status":"queued"}
|
||||
```
|
||||
|
||||
### Step 3: Poll the request
|
||||
```python
|
||||
while True:
|
||||
try:
|
||||
print("\nPolling URL", polling_url)
|
||||
polling_response = requests.get(
|
||||
url=polling_url,
|
||||
headers={
|
||||
"Authorization": f"Bearer {generated_key}"
|
||||
}
|
||||
)
|
||||
print("\nResponse from polling url", polling_response.text)
|
||||
polling_response = polling_response.json()
|
||||
status = polling_response.get("status", None)
|
||||
if status == "finished":
|
||||
llm_response = polling_response["result"]
|
||||
print("LLM Response")
|
||||
print(llm_response)
|
||||
break
|
||||
time.sleep(0.5)
|
||||
except Exception as e:
|
||||
print("got exception in polling", e)
|
||||
break
|
||||
```
|
||||
|
||||
#### Output
|
||||
```shell
|
||||
Polling URL https://api.litellm.ai/queue/response/0e3d9e98-5d56-4d07-9cc8-c34b7e6658d7
|
||||
|
||||
Response from polling url {"status":"queued"}
|
||||
|
||||
Polling URL https://api.litellm.ai/queue/response/0e3d9e98-5d56-4d07-9cc8-c34b7e6658d7
|
||||
|
||||
Response from polling url {"status":"queued"}
|
||||
|
||||
Polling URL https://api.litellm.ai/queue/response/0e3d9e98-5d56-4d07-9cc8-c34b7e6658d7
|
||||
|
||||
Response from polling url
|
||||
{"status":"finished","result":{"id":"chatcmpl-8NYRce4IeI4NzYyodT3NNp8fk5cSW","choices":[{"finish_reason":"stop","index":0,"message":{"content":"I am an AI assistant and do not have a physical presence or personal identity. You can simply refer to me as \"Assistant.\" How may I assist you today?","role":"assistant"}}],"created":1700624639,"model":"gpt-3.5-turbo-0613","object":"chat.completion","system_fingerprint":null,"usage":{"completion_tokens":33,"prompt_tokens":17,"total_tokens":50}}}
|
||||
|
||||
```
|
|
@ -61,11 +61,13 @@ const sidebars = {
|
|||
},
|
||||
items: [
|
||||
"providers/openai",
|
||||
"providers/openai_compatible",
|
||||
"providers/azure",
|
||||
"providers/huggingface",
|
||||
"providers/ollama",
|
||||
"providers/vertex",
|
||||
"providers/palm",
|
||||
"providers/mistral",
|
||||
"providers/anthropic",
|
||||
"providers/aws_sagemaker",
|
||||
"providers/bedrock",
|
||||
|
@ -96,10 +98,14 @@ const sidebars = {
|
|||
},
|
||||
items: [
|
||||
"proxy/quick_start",
|
||||
"proxy/configs",
|
||||
"proxy/configs",
|
||||
"proxy/embedding",
|
||||
"proxy/load_balancing",
|
||||
"proxy/virtual_keys",
|
||||
"proxy/model_management",
|
||||
"proxy/reliability",
|
||||
"proxy/health",
|
||||
"proxy/call_hooks",
|
||||
"proxy/caching",
|
||||
"proxy/logging",
|
||||
"proxy/cli",
|
||||
|
@ -189,6 +195,7 @@ const sidebars = {
|
|||
slug: '/project',
|
||||
},
|
||||
items: [
|
||||
"projects/Docq.AI",
|
||||
"projects/OpenInterpreter",
|
||||
"projects/FastREPL",
|
||||
"projects/PROMPTMETHEUS",
|
||||
|
|
File diff suppressed because it is too large
|
@ -10,7 +10,7 @@ success_callback: List[Union[str, Callable]] = []
|
|||
failure_callback: List[Union[str, Callable]] = []
|
||||
callbacks: List[Callable] = []
|
||||
_async_input_callback: List[Callable] = [] # internal variable - async custom callbacks are routed here.
|
||||
_async_success_callback: List[Callable] = [] # internal variable - async custom callbacks are routed here.
|
||||
_async_success_callback: List[Union[str, Callable]] = [] # internal variable - async custom callbacks are routed here.
|
||||
_async_failure_callback: List[Callable] = [] # internal variable - async custom callbacks are routed here.
|
||||
pre_call_rules: List[Callable] = []
|
||||
post_call_rules: List[Callable] = []
|
||||
|
@ -48,6 +48,8 @@ cache: Optional[Cache] = None # cache object <- use this - https://docs.litellm.
|
|||
model_alias_map: Dict[str, str] = {}
|
||||
model_group_alias_map: Dict[str, str] = {}
|
||||
max_budget: float = 0.0 # set the max budget across all providers
|
||||
_openai_completion_params = ["functions", "function_call", "temperature", "temperature", "top_p", "n", "stream", "stop", "max_tokens", "presence_penalty", "frequency_penalty", "logit_bias", "user", "request_timeout", "api_base", "api_version", "api_key", "deployment_id", "organization", "base_url", "default_headers", "timeout", "response_format", "seed", "tools", "tool_choice", "max_retries"]
|
||||
_litellm_completion_params = ["metadata", "acompletion", "caching", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response", "self", "client", "rpm", "tpm", "input_cost_per_token", "output_cost_per_token", "hf_model_name", "model_info", "proxy_server_request", "preset_cache_key"]
|
||||
_current_cost = 0 # private variable, used if max budget is set
|
||||
error_logs: Dict = {}
|
||||
add_function_to_prompt: bool = False # if function calling not supported by api, append function call details to system prompt
|
||||
|
@ -56,6 +58,7 @@ aclient_session: Optional[httpx.AsyncClient] = None
|
|||
model_fallbacks: Optional[List] = None # Deprecated for 'litellm.fallbacks'
|
||||
model_cost_map_url: str = "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json"
|
||||
suppress_debug_info = False
|
||||
dynamodb_table_name: Optional[str] = None
|
||||
#### RELIABILITY ####
|
||||
request_timeout: Optional[float] = 6000
|
||||
num_retries: Optional[int] = None
|
||||
|
@ -107,6 +110,8 @@ open_ai_text_completion_models: List = []
|
|||
cohere_models: List = []
|
||||
anthropic_models: List = []
|
||||
openrouter_models: List = []
|
||||
vertex_language_models: List = []
|
||||
vertex_vision_models: List = []
|
||||
vertex_chat_models: List = []
|
||||
vertex_code_chat_models: List = []
|
||||
vertex_text_models: List = []
|
||||
|
@ -133,6 +138,10 @@ for key, value in model_cost.items():
|
|||
vertex_text_models.append(key)
|
||||
elif value.get('litellm_provider') == 'vertex_ai-code-text-models':
|
||||
vertex_code_text_models.append(key)
|
||||
elif value.get('litellm_provider') == 'vertex_ai-language-models':
|
||||
vertex_language_models.append(key)
|
||||
elif value.get('litellm_provider') == 'vertex_ai-vision-models':
|
||||
vertex_vision_models.append(key)
|
||||
elif value.get('litellm_provider') == 'vertex_ai-chat-models':
|
||||
vertex_chat_models.append(key)
|
||||
elif value.get('litellm_provider') == 'vertex_ai-code-chat-models':
|
||||
|
@ -154,7 +163,16 @@ for key, value in model_cost.items():
|
|||
openai_compatible_endpoints: List = [
|
||||
"api.perplexity.ai",
|
||||
"api.endpoints.anyscale.com/v1",
|
||||
"api.deepinfra.com/v1/openai"
|
||||
"api.deepinfra.com/v1/openai",
|
||||
"api.mistral.ai/v1"
|
||||
]
|
||||
|
||||
# this is maintained for Exception Mapping
|
||||
openai_compatible_providers: List = [
|
||||
"anyscale",
|
||||
"mistral",
|
||||
"deepinfra",
|
||||
"perplexity"
|
||||
]
|
||||
|
||||
|
||||
|
@ -266,6 +284,7 @@ model_list = (
|
|||
provider_list: List = [
|
||||
"openai",
|
||||
"custom_openai",
|
||||
"text-completion-openai",
|
||||
"cohere",
|
||||
"anthropic",
|
||||
"replicate",
|
||||
|
@ -287,6 +306,7 @@ provider_list: List = [
|
|||
"deepinfra",
|
||||
"perplexity",
|
||||
"anyscale",
|
||||
"mistral",
|
||||
"maritalk",
|
||||
"custom", # custom apis
|
||||
]
|
||||
|
@ -396,6 +416,7 @@ from .exceptions import (
|
|||
AuthenticationError,
|
||||
InvalidRequestError,
|
||||
BadRequestError,
|
||||
NotFoundError,
|
||||
RateLimitError,
|
||||
ServiceUnavailableError,
|
||||
OpenAIError,
|
||||
|
@ -404,7 +425,8 @@ from .exceptions import (
|
|||
APIError,
|
||||
Timeout,
|
||||
APIConnectionError,
|
||||
APIResponseValidationError
|
||||
APIResponseValidationError,
|
||||
UnprocessableEntityError
|
||||
)
|
||||
from .budget_manager import BudgetManager
|
||||
from .proxy.proxy_cli import run_server
|
||||
|
|
|
@ -10,19 +10,7 @@
|
|||
import litellm
|
||||
import time, logging
|
||||
import json, traceback, ast
|
||||
from typing import Optional
|
||||
|
||||
def get_prompt(*args, **kwargs):
|
||||
# make this safe checks, it should not throw any exceptions
|
||||
if len(args) > 1:
|
||||
messages = args[1]
|
||||
prompt = " ".join(message["content"] for message in messages)
|
||||
return prompt
|
||||
if "messages" in kwargs:
|
||||
messages = kwargs["messages"]
|
||||
prompt = " ".join(message["content"] for message in messages)
|
||||
return prompt
|
||||
return None
|
||||
from typing import Optional, Literal, List
|
||||
|
||||
def print_verbose(print_statement):
|
||||
try:
|
||||
|
@ -174,34 +162,36 @@ class DualCache(BaseCache):
|
|||
if self.redis_cache is not None:
|
||||
self.redis_cache.flush_cache()
|
||||
|
||||
#### LiteLLM.Completion Cache ####
|
||||
#### LiteLLM.Completion / Embedding Cache ####
|
||||
class Cache:
|
||||
def __init__(
|
||||
self,
|
||||
type="local",
|
||||
host=None,
|
||||
port=None,
|
||||
password=None,
|
||||
type: Optional[Literal["local", "redis"]] = "local",
|
||||
host: Optional[str] = None,
|
||||
port: Optional[str] = None,
|
||||
password: Optional[str] = None,
|
||||
supported_call_types: Optional[List[Literal["completion", "acompletion", "embedding", "aembedding"]]] = ["completion", "acompletion", "embedding", "aembedding"],
|
||||
**kwargs
|
||||
):
|
||||
"""
|
||||
Initializes the cache based on the given type.
|
||||
|
||||
Args:
|
||||
type (str, optional): The type of cache to initialize. Defaults to "local".
|
||||
type (str, optional): The type of cache to initialize. Can be "local" or "redis". Defaults to "local".
|
||||
host (str, optional): The host address for the Redis cache. Required if type is "redis".
|
||||
port (int, optional): The port number for the Redis cache. Required if type is "redis".
|
||||
password (str, optional): The password for the Redis cache. Required if type is "redis".
|
||||
supported_call_types (list, optional): List of call types to cache for. Defaults to cache == on for all call types.
|
||||
**kwargs: Additional keyword arguments for redis.Redis() cache
|
||||
|
||||
Raises:
|
||||
ValueError: If an invalid cache type is provided.
|
||||
|
||||
Returns:
|
||||
None
|
||||
None. Cache is set as a litellm param
|
||||
"""
|
||||
if type == "redis":
|
||||
self.cache = RedisCache(host, port, password, **kwargs)
|
||||
self.cache: BaseCache = RedisCache(host, port, password, **kwargs)
|
||||
if type == "local":
|
||||
self.cache = InMemoryCache()
|
||||
if "cache" not in litellm.input_callback:
|
||||
|
@ -210,6 +200,7 @@ class Cache:
|
|||
litellm.success_callback.append("cache")
|
||||
if "cache" not in litellm._async_success_callback:
|
||||
litellm._async_success_callback.append("cache")
|
||||
self.supported_call_types = supported_call_types # default to ["completion", "acompletion", "embedding", "aembedding"]
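
For reference, a minimal sketch of how these new constructor options are used from application code, assuming the cache object is registered as a litellm param (per the docstring above) and that the class is importable from `litellm.caching`:

```python
import litellm
from litellm.caching import Cache

# local in-memory cache, only for chat completion calls (embeddings are skipped)
litellm.cache = Cache(
    type="local",  # or "redis" with host/port/password
    supported_call_types=["completion", "acompletion"],
)
```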
|
||||
|
||||
def get_cache_key(self, *args, **kwargs):
|
||||
"""
|
||||
|
@ -222,29 +213,55 @@ class Cache:
|
|||
Returns:
|
||||
str: The cache key generated from the arguments, or None if no cache key could be generated.
|
||||
"""
|
||||
cache_key =""
|
||||
cache_key = ""
|
||||
print_verbose(f"\nGetting Cache key. Kwargs: {kwargs}")
|
||||
|
||||
# for streaming, we use preset_cache_key. It's created in wrapper(), we do this because optional params like max_tokens, get transformed for bedrock -> max_new_tokens
|
||||
if kwargs.get("litellm_params", {}).get("preset_cache_key", None) is not None:
|
||||
print_verbose(f"\nReturning preset cache key: {cache_key}")
|
||||
return kwargs.get("litellm_params", {}).get("preset_cache_key", None)
|
||||
|
||||
# sort kwargs by keys, since model: [gpt-4, temperature: 0.2, max_tokens: 200] == [temperature: 0.2, max_tokens: 200, model: gpt-4]
|
||||
completion_kwargs = ["model", "messages", "temperature", "top_p", "n", "stop", "max_tokens", "presence_penalty", "frequency_penalty", "logit_bias", "user", "response_format", "seed", "tools", "tool_choice"]
|
||||
for param in completion_kwargs:
|
||||
embedding_only_kwargs = ["input", "encoding_format"] # embedding kwargs = model, input, user, encoding_format. Model, user are checked in completion_kwargs
|
||||
|
||||
# combined_kwargs - NEEDS to be ordered across get_cache_key(). Do not use a set()
|
||||
combined_kwargs = completion_kwargs + embedding_only_kwargs
|
||||
for param in combined_kwargs:
|
||||
# ignore litellm params here
|
||||
if param in kwargs:
|
||||
# check if param == model and model_group is passed in, then override model with model_group
|
||||
if param == "model":
|
||||
model_group = None
|
||||
caching_group = None
|
||||
metadata = kwargs.get("metadata", None)
|
||||
litellm_params = kwargs.get("litellm_params", {})
|
||||
if metadata is not None:
|
||||
model_group = metadata.get("model_group")
|
||||
model_group = metadata.get("model_group", None)
|
||||
caching_groups = metadata.get("caching_groups", None)
|
||||
if caching_groups:
|
||||
for group in caching_groups:
|
||||
if model_group in group:
|
||||
caching_group = group
|
||||
break
|
||||
if litellm_params is not None:
|
||||
metadata = litellm_params.get("metadata", None)
|
||||
if metadata is not None:
|
||||
model_group = metadata.get("model_group", None)
|
||||
param_value = model_group or kwargs[param] # use model_group if it exists, else use kwargs["model"]
|
||||
caching_groups = metadata.get("caching_groups", None)
|
||||
if caching_groups:
|
||||
for group in caching_groups:
|
||||
if model_group in group:
|
||||
caching_group = group
|
||||
break
|
||||
param_value = caching_group or model_group or kwargs[param] # use caching_group, if set then model_group if it exists, else use kwargs["model"]
|
||||
else:
|
||||
if kwargs[param] is None:
|
||||
continue # ignore None params
|
||||
param_value = kwargs[param]
|
||||
cache_key+= f"{str(param)}: {str(param_value)}"
|
||||
print_verbose(f"\nCreated cache key: {cache_key}")
|
||||
return cache_key
|
||||
|
||||
def generate_streaming_content(self, content):
|
||||
|
@ -297,4 +314,9 @@ class Cache:
|
|||
result = result.model_dump_json()
|
||||
self.cache.set_cache(cache_key, result, **kwargs)
|
||||
except Exception as e:
|
||||
print_verbose(f"LiteLLM Cache: Excepton add_cache: {str(e)}")
|
||||
traceback.print_exc()
|
||||
pass
|
||||
|
||||
async def _async_add_cache(self, result, *args, **kwargs):
|
||||
self.add_cache(result, *args, **kwargs)
|
|
@ -12,16 +12,19 @@
|
|||
from openai import (
|
||||
AuthenticationError,
|
||||
BadRequestError,
|
||||
NotFoundError,
|
||||
RateLimitError,
|
||||
APIStatusError,
|
||||
OpenAIError,
|
||||
APIError,
|
||||
APITimeoutError,
|
||||
APIConnectionError,
|
||||
APIResponseValidationError
|
||||
APIResponseValidationError,
|
||||
UnprocessableEntityError
|
||||
)
|
||||
import httpx
|
||||
|
||||
|
||||
class AuthenticationError(AuthenticationError): # type: ignore
|
||||
def __init__(self, message, llm_provider, model, response: httpx.Response):
|
||||
self.status_code = 401
|
||||
|
@ -34,6 +37,20 @@ class AuthenticationError(AuthenticationError): # type: ignore
|
|||
body=None
|
||||
) # Call the base class constructor with the parameters it needs
|
||||
|
||||
# raise when invalid models passed, example gpt-8
|
||||
class NotFoundError(NotFoundError): # type: ignore
|
||||
def __init__(self, message, model, llm_provider, response: httpx.Response):
|
||||
self.status_code = 404
|
||||
self.message = message
|
||||
self.model = model
|
||||
self.llm_provider = llm_provider
|
||||
super().__init__(
|
||||
self.message,
|
||||
response=response,
|
||||
body=None
|
||||
) # Call the base class constructor with the parameters it needs
|
||||
|
||||
|
||||
class BadRequestError(BadRequestError): # type: ignore
|
||||
def __init__(self, message, model, llm_provider, response: httpx.Response):
|
||||
self.status_code = 400
|
||||
|
@ -46,6 +63,18 @@ class BadRequestError(BadRequestError): # type: ignore
|
|||
body=None
|
||||
) # Call the base class constructor with the parameters it needs
|
||||
|
||||
class UnprocessableEntityError(UnprocessableEntityError): # type: ignore
|
||||
def __init__(self, message, model, llm_provider, response: httpx.Response):
|
||||
self.status_code = 422
|
||||
self.message = message
|
||||
self.model = model
|
||||
self.llm_provider = llm_provider
|
||||
super().__init__(
|
||||
self.message,
|
||||
response=response,
|
||||
body=None
|
||||
) # Call the base class constructor with the parameters it needs
|
||||
|
||||
class Timeout(APITimeoutError): # type: ignore
|
||||
def __init__(self, message, model, llm_provider):
|
||||
self.status_code = 408
|
||||
|
|
|
@ -2,8 +2,9 @@
|
|||
# On success, logs events to Promptlayer
|
||||
import dotenv, os
|
||||
import requests
|
||||
import requests
|
||||
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
from litellm.caching import DualCache
|
||||
from typing import Literal
|
||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
||||
import traceback
|
||||
|
||||
|
@ -27,7 +28,12 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callback
|
|||
|
||||
def log_failure_event(self, kwargs, response_obj, start_time, end_time):
|
||||
pass
|
||||
|
||||
#### ASYNC ####
|
||||
|
||||
async def async_log_stream_event(self, kwargs, response_obj, start_time, end_time):
|
||||
pass
|
||||
|
||||
async def async_log_pre_api_call(self, model, messages, kwargs):
|
||||
pass
|
||||
|
||||
|
@ -37,6 +43,16 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callback
|
|||
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
|
||||
pass
|
||||
|
||||
#### CALL HOOKS - proxy only ####
|
||||
"""
|
||||
Control / modify the incoming and outgoing data before calling the model
|
||||
"""
|
||||
async def async_pre_call_hook(self, user_api_key_dict: UserAPIKeyAuth, cache: DualCache, data: dict, call_type: Literal["completion", "embeddings"]):
|
||||
pass
|
||||
|
||||
async def async_post_call_failure_hook(self, original_exception: Exception, user_api_key_dict: UserAPIKeyAuth):
|
||||
pass
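
For context, a minimal sketch of what a proxy call-hook subclass could look like, based only on the method signatures above; the module path `litellm.integrations.custom_logger` and the convention of returning the (possibly modified) request data are assumptions:

```python
from litellm.integrations.custom_logger import CustomLogger  # assumed module path
from litellm.proxy._types import UserAPIKeyAuth
from litellm.caching import DualCache

class MyProxyHandler(CustomLogger):
    async def async_pre_call_hook(self, user_api_key_dict: UserAPIKeyAuth, cache: DualCache, data: dict, call_type):
        # reject oversized prompts before they ever reach the model
        if len(str(data.get("messages", ""))) > 10_000:
            raise ValueError("Prompt too long")
        return data  # assumed: the hook hands back the request body for the proxy to use

    async def async_post_call_failure_hook(self, original_exception: Exception, user_api_key_dict: UserAPIKeyAuth):
        # e.g. forward failures to your own alerting
        print(f"call failed: {original_exception}")
```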
|
||||
|
||||
#### SINGLE-USE #### - https://docs.litellm.ai/docs/observability/custom_callback#using-your-custom-callback-function
|
||||
|
||||
def log_input_event(self, model, messages, kwargs, print_verbose, callback_func):
|
||||
|
|
litellm/integrations/dynamodb.py (new file, 82 lines)
@@ -0,0 +1,82 @@
|
|||
#### What this does ####
|
||||
# On success + failure, log events to DynamoDB
|
||||
|
||||
import dotenv, os
|
||||
import requests
|
||||
|
||||
dotenv.load_dotenv() # Loading env variables using dotenv
|
||||
import traceback
|
||||
import datetime, subprocess, sys
|
||||
import litellm, uuid
|
||||
from litellm._logging import print_verbose
|
||||
|
||||
class DyanmoDBLogger:
|
||||
# Class variables or attributes
|
||||
|
||||
def __init__(self):
|
||||
# Instance variables
|
||||
import boto3
|
||||
self.dynamodb = boto3.resource('dynamodb', region_name=os.environ["AWS_REGION_NAME"])
|
||||
if litellm.dynamodb_table_name is None:
|
||||
raise ValueError("LiteLLM Error, trying to use DynamoDB but not table name passed. Create a table and set `litellm.dynamodb_table_name=<your-table>`")
|
||||
self.table_name = litellm.dynamodb_table_name
|
||||
|
||||
async def _async_log_event(self, kwargs, response_obj, start_time, end_time, print_verbose):
|
||||
self.log_event(kwargs, response_obj, start_time, end_time, print_verbose)
|
||||
def log_event(self, kwargs, response_obj, start_time, end_time, print_verbose):
|
||||
try:
|
||||
print_verbose(
|
||||
f"DynamoDB Logging - Enters logging function for model {kwargs}"
|
||||
)
|
||||
|
||||
# construct payload to send to DynamoDB
|
||||
# follows the same params as langfuse.py
|
||||
litellm_params = kwargs.get("litellm_params", {})
|
||||
metadata = litellm_params.get("metadata", {}) or {} # if litellm_params['metadata'] == None
|
||||
messages = kwargs.get("messages")
|
||||
optional_params = kwargs.get("optional_params", {})
|
||||
call_type = kwargs.get("call_type", "litellm.completion")
|
||||
usage = response_obj["usage"]
|
||||
id = response_obj.get("id", str(uuid.uuid4()))
|
||||
|
||||
# Build the initial payload
|
||||
payload = {
|
||||
"id": id,
|
||||
"call_type": call_type,
|
||||
"startTime": start_time,
|
||||
"endTime": end_time,
|
||||
"model": kwargs.get("model", ""),
|
||||
"user": kwargs.get("user", ""),
|
||||
"modelParameters": optional_params,
|
||||
"messages": messages,
|
||||
"response": response_obj,
|
||||
"usage": usage,
|
||||
"metadata": metadata
|
||||
}
|
||||
|
||||
# Ensure everything in the payload is converted to str
|
||||
for key, value in payload.items():
|
||||
try:
|
||||
payload[key] = str(value)
|
||||
except:
|
||||
# non blocking if it can't cast to a str
|
||||
pass
|
||||
|
||||
|
||||
print_verbose(f"\nDynamoDB Logger - Logging payload = {payload}")
|
||||
|
||||
# put data in DynamoDB
|
||||
table = self.dynamodb.Table(self.table_name)
|
||||
# Assuming log_data is a dictionary with log information
|
||||
response = table.put_item(Item=payload)
|
||||
|
||||
print_verbose(f"Response from DynamoDB:{str(response)}")
|
||||
|
||||
print_verbose(
|
||||
f"DynamoDB Layer Logging - final response object: {response_obj}"
|
||||
)
|
||||
return response
|
||||
except:
|
||||
traceback.print_exc()
|
||||
print_verbose(f"DynamoDB Layer Error - {traceback.format_exc()}")
|
||||
pass
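
To tie this back to the proxy docs earlier in this diff, a minimal sketch of enabling the logger directly from the Python SDK (the region value is a placeholder; the table must already exist):

```python
import os
import litellm

os.environ["AWS_REGION_NAME"] = "us-west-2"      # placeholder - region used by boto3.resource() above
litellm.dynamodb_table_name = "your-table-name"  # required, otherwise __init__ raises ValueError
litellm.success_callback = ["dynamodb"]          # same as `success_callback: ["dynamodb"]` in config.yaml

response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "This is a test"}],
)
```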
|
|
@ -58,7 +58,7 @@ class LangFuseLogger:
|
|||
model=kwargs['model'],
|
||||
modelParameters=optional_params,
|
||||
prompt=prompt,
|
||||
completion=response_obj['choices'][0]['message'],
|
||||
completion=response_obj['choices'][0]['message'].json(),
|
||||
usage=Usage(
|
||||
prompt_tokens=response_obj['usage']['prompt_tokens'],
|
||||
completion_tokens=response_obj['usage']['completion_tokens']
|
||||
|
@ -70,6 +70,9 @@ class LangFuseLogger:
|
|||
f"Langfuse Layer Logging - final response object: {response_obj}"
|
||||
)
|
||||
except:
|
||||
# traceback.print_exc()
|
||||
traceback.print_exc()
|
||||
print_verbose(f"Langfuse Layer Error - {traceback.format_exc()}")
|
||||
pass
|
||||
|
||||
async def _async_log_event(self, kwargs, response_obj, start_time, end_time, print_verbose):
|
||||
self.log_event(kwargs, response_obj, start_time, end_time, print_verbose)
|
||||
|
|
|
@ -58,7 +58,7 @@ class LangsmithLogger:
|
|||
"inputs": {
|
||||
**new_kwargs
|
||||
},
|
||||
"outputs": response_obj,
|
||||
"outputs": response_obj.json(),
|
||||
"session_name": project_name,
|
||||
"start_time": start_time,
|
||||
"end_time": end_time,
|
||||
|
|
|
@ -1,7 +1,8 @@
|
|||
class TraceloopLogger:
|
||||
def __init__(self):
|
||||
from traceloop.sdk.tracing.tracing import TracerWrapper
|
||||
|
||||
from traceloop.sdk import Traceloop
|
||||
Traceloop.init(app_name="Litellm-Server", disable_batch=True)
|
||||
self.tracer_wrapper = TracerWrapper()
|
||||
|
||||
def log_event(self, kwargs, response_obj, start_time, end_time, print_verbose):
|
||||
|
|
|
@ -196,8 +196,19 @@ class AzureChatCompletion(BaseLLM):
|
|||
else:
|
||||
azure_client = client
|
||||
response = azure_client.chat.completions.create(**data) # type: ignore
|
||||
response.model = "azure/" + str(response.model)
|
||||
return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response)
|
||||
stringified_response = response.model_dump_json()
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
input=messages,
|
||||
api_key=api_key,
|
||||
original_response=stringified_response,
|
||||
additional_args={
|
||||
"headers": headers,
|
||||
"api_version": api_version,
|
||||
"api_base": api_base,
|
||||
},
|
||||
)
|
||||
return convert_to_model_response_object(response_object=json.loads(stringified_response), model_response_object=model_response)
|
||||
except AzureOpenAIError as e:
|
||||
exception_mapping_worked = True
|
||||
raise e
|
||||
|
@ -318,7 +329,10 @@ class AzureChatCompletion(BaseLLM):
|
|||
data: dict,
|
||||
model_response: ModelResponse,
|
||||
azure_client_params: dict,
|
||||
api_key: str,
|
||||
input: list,
|
||||
client=None,
|
||||
logging_obj=None
|
||||
):
|
||||
response = None
|
||||
try:
|
||||
|
@ -327,8 +341,23 @@ class AzureChatCompletion(BaseLLM):
|
|||
else:
|
||||
openai_aclient = client
|
||||
response = await openai_aclient.embeddings.create(**data)
|
||||
return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response, response_type="embedding")
|
||||
stringified_response = response.model_dump_json()
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
input=input,
|
||||
api_key=api_key,
|
||||
additional_args={"complete_input_dict": data},
|
||||
original_response=stringified_response,
|
||||
)
|
||||
return convert_to_model_response_object(response_object=json.loads(stringified_response), model_response_object=model_response, response_type="embedding")
|
||||
except Exception as e:
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
input=input,
|
||||
api_key=api_key,
|
||||
additional_args={"complete_input_dict": data},
|
||||
original_response=str(e),
|
||||
)
|
||||
raise e
|
||||
|
||||
def embedding(self,
|
||||
|
@ -372,13 +401,7 @@ class AzureChatCompletion(BaseLLM):
|
|||
azure_client_params["api_key"] = api_key
|
||||
elif azure_ad_token is not None:
|
||||
azure_client_params["azure_ad_token"] = azure_ad_token
|
||||
if aembedding == True:
|
||||
response = self.aembedding(data=data, model_response=model_response, azure_client_params=azure_client_params)
|
||||
return response
|
||||
if client is None:
|
||||
azure_client = AzureOpenAI(**azure_client_params) # type: ignore
|
||||
else:
|
||||
azure_client = client
|
||||
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=input,
|
||||
|
@ -391,6 +414,14 @@ class AzureChatCompletion(BaseLLM):
|
|||
}
|
||||
},
|
||||
)
|
||||
|
||||
if aembedding == True:
|
||||
response = self.aembedding(data=data, input=input, logging_obj=logging_obj, api_key=api_key, model_response=model_response, azure_client_params=azure_client_params)
|
||||
return response
|
||||
if client is None:
|
||||
azure_client = AzureOpenAI(**azure_client_params) # type: ignore
|
||||
else:
|
||||
azure_client = client
|
||||
## COMPLETION CALL
|
||||
response = azure_client.embeddings.create(**data) # type: ignore
|
||||
## LOGGING
|
||||
|
|
|
@ -482,7 +482,7 @@ def completion(
|
|||
logging_obj.post_call(
|
||||
input=prompt,
|
||||
api_key="",
|
||||
original_response=response_body,
|
||||
original_response=json.dumps(response_body),
|
||||
additional_args={"complete_input_dict": data},
|
||||
)
|
||||
print_verbose(f"raw model_response: {response}")
|
||||
|
@ -552,6 +552,7 @@ def _embedding_func_single(
|
|||
## FORMAT EMBEDDING INPUT ##
|
||||
provider = model.split(".")[0]
|
||||
inference_params = copy.deepcopy(optional_params)
|
||||
inference_params.pop("user", None) # make sure user is not passed in for bedrock call
|
||||
if provider == "amazon":
|
||||
input = input.replace(os.linesep, " ")
|
||||
data = {"inputText": input, **inference_params}
|
||||
|
@ -587,7 +588,7 @@ def _embedding_func_single(
|
|||
input=input,
|
||||
api_key="",
|
||||
additional_args={"complete_input_dict": data},
|
||||
original_response=response_body,
|
||||
original_response=json.dumps(response_body),
|
||||
)
|
||||
if provider == "cohere":
|
||||
response = response_body.get("embeddings")
|
||||
|
@ -650,14 +651,5 @@ def embedding(
|
|||
total_tokens=input_tokens + 0
|
||||
)
|
||||
model_response.usage = usage
|
||||
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
input=input,
|
||||
api_key=api_key,
|
||||
additional_args={"complete_input_dict": {"model": model,
|
||||
"texts": input}},
|
||||
original_response=embeddings,
|
||||
)
|
||||
|
||||
return model_response
|
||||
|
|
|
@ -542,7 +542,7 @@ class Huggingface(BaseLLM):
|
|||
logging_obj.pre_call(
|
||||
input=input,
|
||||
api_key=api_key,
|
||||
additional_args={"complete_input_dict": data},
|
||||
additional_args={"complete_input_dict": data, "headers": headers, "api_base": embed_url},
|
||||
)
|
||||
## COMPLETION CALL
|
||||
response = requests.post(
|
||||
|
@ -584,6 +584,14 @@ class Huggingface(BaseLLM):
|
|||
"embedding": embedding # flatten list returned from hf
|
||||
}
|
||||
)
|
||||
elif isinstance(embedding, list) and isinstance(embedding[0], float):
|
||||
output_data.append(
|
||||
{
|
||||
"object": "embedding",
|
||||
"index": idx,
|
||||
"embedding": embedding # flatten list returned from hf
|
||||
}
|
||||
)
|
||||
else:
|
||||
output_data.append(
|
||||
{
|
||||
|
|
|
@ -1,10 +1,9 @@
|
|||
import requests, types
|
||||
import requests, types, time
|
||||
import json
|
||||
import traceback
|
||||
from typing import Optional
|
||||
import litellm
|
||||
import httpx
|
||||
|
||||
import httpx, aiohttp, asyncio
|
||||
try:
|
||||
from async_generator import async_generator, yield_ # optional dependency
|
||||
async_generator_imported = True
|
||||
|
@ -115,6 +114,9 @@ def get_ollama_response_stream(
|
|||
prompt="Why is the sky blue?",
|
||||
optional_params=None,
|
||||
logging_obj=None,
|
||||
acompletion: bool = False,
|
||||
model_response=None,
|
||||
encoding=None
|
||||
):
|
||||
if api_base.endswith("/api/generate"):
|
||||
url = api_base
|
||||
|
@ -136,8 +138,19 @@ def get_ollama_response_stream(
|
|||
logging_obj.pre_call(
|
||||
input=None,
|
||||
api_key=None,
|
||||
additional_args={"api_base": url, "complete_input_dict": data},
|
||||
additional_args={"api_base": url, "complete_input_dict": data, "headers": {}, "acompletion": acompletion,},
|
||||
)
|
||||
if acompletion is True:
|
||||
if optional_params.get("stream", False):
|
||||
response = ollama_async_streaming(url=url, data=data, model_response=model_response, encoding=encoding, logging_obj=logging_obj)
|
||||
else:
|
||||
response = ollama_acompletion(url=url, data=data, model_response=model_response, encoding=encoding, logging_obj=logging_obj)
|
||||
return response
|
||||
|
||||
else:
|
||||
return ollama_completion_stream(url=url, data=data)
|
||||
|
||||
def ollama_completion_stream(url, data):
|
||||
session = requests.Session()
|
||||
|
||||
with session.post(url, json=data, stream=True) as resp:
|
||||
|
@ -169,41 +182,38 @@ def get_ollama_response_stream(
|
|||
traceback.print_exc()
|
||||
session.close()
|
||||
|
||||
if async_generator_imported:
|
||||
# ollama implementation
|
||||
@async_generator
|
||||
async def async_get_ollama_response_stream(
|
||||
api_base="http://localhost:11434",
|
||||
model="llama2",
|
||||
prompt="Why is the sky blue?",
|
||||
optional_params=None,
|
||||
logging_obj=None,
|
||||
):
|
||||
url = f"{api_base}/api/generate"
|
||||
|
||||
## Load Config
|
||||
config=litellm.OllamaConfig.get_config()
|
||||
for k, v in config.items():
|
||||
if k not in optional_params: # completion(top_k=3) > cohere_config(top_k=3) <- allows for dynamic variables to be passed in
|
||||
optional_params[k] = v
|
||||
|
||||
data = {
|
||||
"model": model,
|
||||
"prompt": prompt,
|
||||
**optional_params
|
||||
}
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=None,
|
||||
api_key=None,
|
||||
additional_args={"api_base": url, "complete_input_dict": data},
|
||||
)
|
||||
session = requests.Session()
|
||||
async def ollama_async_streaming(url, data, model_response, encoding, logging_obj):
|
||||
try:
|
||||
client = httpx.AsyncClient()
|
||||
async with client.stream(
|
||||
url=f"{url}",
|
||||
json=data,
|
||||
method="POST",
|
||||
timeout=litellm.request_timeout
|
||||
) as response:
|
||||
if response.status_code != 200:
|
||||
raise OllamaError(status_code=response.status_code, message=response.text)
|
||||
|
||||
streamwrapper = litellm.CustomStreamWrapper(completion_stream=response.aiter_lines(), model=data['model'], custom_llm_provider="ollama",logging_obj=logging_obj)
|
||||
async for transformed_chunk in streamwrapper:
|
||||
yield transformed_chunk
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
|
||||
with session.post(url, json=data, stream=True) as resp:
|
||||
if resp.status_code != 200:
|
||||
raise OllamaError(status_code=resp.status_code, message=resp.text)
|
||||
for line in resp.iter_lines():
|
||||
async def ollama_acompletion(url, data, model_response, encoding, logging_obj):
|
||||
data["stream"] = False
|
||||
try:
|
||||
timeout = aiohttp.ClientTimeout(total=600) # 10 minutes
|
||||
async with aiohttp.ClientSession(timeout=timeout) as session:
|
||||
resp = await session.post(url, json=data)
|
||||
|
||||
if resp.status != 200:
|
||||
text = await resp.text()
|
||||
raise OllamaError(status_code=resp.status, message=text)
|
||||
|
||||
completion_string = ""
|
||||
async for line in resp.content.iter_any():
|
||||
if line:
|
||||
try:
|
||||
json_chunk = line.decode("utf-8")
|
||||
|
@ -217,15 +227,24 @@ if async_generator_imported:
|
|||
"content": "",
|
||||
"error": j
|
||||
}
|
||||
await yield_({"choices": [{"delta": completion_obj}]})
|
||||
raise Exception(f"OllamError - {chunk}")
|
||||
if "response" in j:
|
||||
completion_obj = {
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"content": j["response"],
|
||||
}
|
||||
completion_obj["content"] = j["response"]
|
||||
await yield_({"choices": [{"delta": completion_obj}]})
|
||||
completion_string = completion_string + completion_obj["content"]
|
||||
except Exception as e:
|
||||
import logging
|
||||
logging.debug(f"Error decoding JSON: {e}")
|
||||
session.close()
|
||||
traceback.print_exc()
|
||||
|
||||
## RESPONSE OBJECT
|
||||
model_response["choices"][0]["finish_reason"] = "stop"
|
||||
model_response["choices"][0]["message"]["content"] = completion_string
|
||||
model_response["created"] = int(time.time())
|
||||
model_response["model"] = "ollama/" + data['model']
|
||||
prompt_tokens = len(encoding.encode(data['prompt'])) # type: ignore
|
||||
completion_tokens = len(encoding.encode(completion_string))
|
||||
model_response["usage"] = litellm.Usage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=prompt_tokens + completion_tokens)
|
||||
return model_response
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
|
|
|
@ -195,23 +195,23 @@ class OpenAIChatCompletion(BaseLLM):
|
|||
**optional_params
|
||||
}
|
||||
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=messages,
|
||||
api_key=api_key,
|
||||
additional_args={"headers": headers, "api_base": api_base, "acompletion": acompletion, "complete_input_dict": data},
|
||||
)
|
||||
|
||||
try:
|
||||
max_retries = data.pop("max_retries", 2)
|
||||
if acompletion is True:
|
||||
if optional_params.get("stream", False):
|
||||
return self.async_streaming(logging_obj=logging_obj, data=data, model=model, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries)
|
||||
return self.async_streaming(logging_obj=logging_obj, headers=headers, data=data, model=model, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries)
|
||||
else:
|
||||
return self.acompletion(data=data, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries)
|
||||
return self.acompletion(data=data, headers=headers, logging_obj=logging_obj, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries)
|
||||
elif optional_params.get("stream", False):
|
||||
return self.streaming(logging_obj=logging_obj, data=data, model=model, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries)
|
||||
return self.streaming(logging_obj=logging_obj, headers=headers, data=data, model=model, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries)
|
||||
else:
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=messages,
|
||||
api_key=api_key,
|
||||
additional_args={"headers": headers, "api_base": api_base, "acompletion": acompletion, "complete_input_dict": data},
|
||||
)
|
||||
|
||||
if not isinstance(max_retries, int):
|
||||
raise OpenAIError(status_code=422, message="max retries must be an int")
|
||||
if client is None:
|
||||
|
@ -219,13 +219,14 @@ class OpenAIChatCompletion(BaseLLM):
|
|||
else:
|
||||
openai_client = client
|
||||
response = openai_client.chat.completions.create(**data) # type: ignore
|
||||
stringified_response = response.model_dump_json()
|
||||
logging_obj.post_call(
|
||||
input=None,
|
||||
input=messages,
|
||||
api_key=api_key,
|
||||
original_response=response,
|
||||
original_response=stringified_response,
|
||||
additional_args={"complete_input_dict": data},
|
||||
)
|
||||
return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response)
|
||||
return convert_to_model_response_object(response_object=json.loads(stringified_response), model_response_object=model_response)
|
||||
except Exception as e:
|
||||
if "Conversation roles must alternate user/assistant" in str(e) or "user and assistant roles should be alternating" in str(e):
|
||||
# reformat messages to ensure user/assistant are alternating, if there's either 2 consecutive 'user' messages or 2 consecutive 'assistant' message, add a blank 'user' or 'assistant' message to ensure compatibility
|
||||
|
@ -259,6 +260,8 @@ class OpenAIChatCompletion(BaseLLM):
|
|||
api_base: Optional[str]=None,
|
||||
client=None,
|
||||
max_retries=None,
|
||||
logging_obj=None,
|
||||
headers=None
|
||||
):
|
||||
response = None
|
||||
try:
|
||||
|
@ -266,16 +269,23 @@ class OpenAIChatCompletion(BaseLLM):
|
|||
openai_aclient = AsyncOpenAI(api_key=api_key, base_url=api_base, http_client=litellm.aclient_session, timeout=timeout, max_retries=max_retries)
|
||||
else:
|
||||
openai_aclient = client
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=data['messages'],
|
||||
api_key=openai_aclient.api_key,
|
||||
additional_args={"headers": {"Authorization": f"Bearer {openai_aclient.api_key}"}, "api_base": openai_aclient._base_url._uri_reference, "acompletion": True, "complete_input_dict": data},
|
||||
)
|
||||
response = await openai_aclient.chat.completions.create(**data)
|
||||
return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response)
|
||||
stringified_response = response.model_dump_json()
|
||||
logging_obj.post_call(
|
||||
input=data['messages'],
|
||||
api_key=api_key,
|
||||
original_response=stringified_response,
|
||||
additional_args={"complete_input_dict": data},
|
||||
)
|
||||
return convert_to_model_response_object(response_object=json.loads(stringified_response), model_response_object=model_response)
|
||||
except Exception as e:
|
||||
if response and hasattr(response, "text"):
|
||||
raise OpenAIError(status_code=500, message=f"{str(e)}\n\nOriginal Response: {response.text}")
|
||||
else:
|
||||
if type(e).__name__ == "ReadTimeout":
|
||||
raise OpenAIError(status_code=408, message=f"{type(e).__name__}")
|
||||
else:
|
||||
raise OpenAIError(status_code=500, message=f"{str(e)}")
|
||||
raise e
|
||||
|
||||
def streaming(self,
|
||||
logging_obj,
|
||||
|
@ -285,12 +295,19 @@ class OpenAIChatCompletion(BaseLLM):
|
|||
api_key: Optional[str]=None,
|
||||
api_base: Optional[str]=None,
|
||||
client = None,
|
||||
max_retries=None
|
||||
max_retries=None,
|
||||
headers=None
|
||||
):
|
||||
if client is None:
|
||||
openai_client = OpenAI(api_key=api_key, base_url=api_base, http_client=litellm.client_session, timeout=timeout, max_retries=max_retries)
|
||||
else:
|
||||
openai_client = client
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=data['messages'],
|
||||
api_key=api_key,
|
||||
additional_args={"headers": headers, "api_base": api_base, "acompletion": False, "complete_input_dict": data},
|
||||
)
|
||||
response = openai_client.chat.completions.create(**data)
|
||||
streamwrapper = CustomStreamWrapper(completion_stream=response, model=model, custom_llm_provider="openai",logging_obj=logging_obj)
|
||||
return streamwrapper
|
||||
|
@ -304,6 +321,7 @@ class OpenAIChatCompletion(BaseLLM):
|
|||
api_base: Optional[str]=None,
|
||||
client=None,
|
||||
max_retries=None,
|
||||
headers=None
|
||||
):
|
||||
response = None
|
||||
try:
|
||||
|
@ -311,6 +329,13 @@ class OpenAIChatCompletion(BaseLLM):
|
|||
openai_aclient = AsyncOpenAI(api_key=api_key, base_url=api_base, http_client=litellm.aclient_session, timeout=timeout, max_retries=max_retries)
|
||||
else:
|
||||
openai_aclient = client
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=data['messages'],
|
||||
api_key=api_key,
|
||||
additional_args={"headers": headers, "api_base": api_base, "acompletion": True, "complete_input_dict": data},
|
||||
)
|
||||
|
||||
response = await openai_aclient.chat.completions.create(**data)
|
||||
streamwrapper = CustomStreamWrapper(completion_stream=response, model=model, custom_llm_provider="openai",logging_obj=logging_obj)
|
||||
async for transformed_chunk in streamwrapper:
|
||||
|
@ -325,6 +350,7 @@ class OpenAIChatCompletion(BaseLLM):
|
|||
raise OpenAIError(status_code=500, message=f"{str(e)}")
|
||||
async def aembedding(
|
||||
self,
|
||||
input: list,
|
||||
data: dict,
|
||||
model_response: ModelResponse,
|
||||
timeout: float,
|
||||
|
@ -332,6 +358,7 @@ class OpenAIChatCompletion(BaseLLM):
|
|||
api_base: Optional[str]=None,
|
||||
client=None,
|
||||
max_retries=None,
|
||||
logging_obj=None
|
||||
):
|
||||
response = None
|
||||
try:
|
||||
|
@ -340,9 +367,24 @@ class OpenAIChatCompletion(BaseLLM):
|
|||
else:
|
||||
openai_aclient = client
|
||||
response = await openai_aclient.embeddings.create(**data) # type: ignore
|
||||
return response
|
||||
stringified_response = response.model_dump_json()
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
input=input,
|
||||
api_key=api_key,
|
||||
additional_args={"complete_input_dict": data},
|
||||
original_response=stringified_response,
|
||||
)
|
||||
return convert_to_model_response_object(response_object=json.loads(stringified_response), model_response_object=model_response, response_type="embedding") # type: ignore
|
||||
except Exception as e:
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
input=input,
|
||||
api_key=api_key,
|
||||
original_response=str(e),
|
||||
)
|
||||
raise e
|
||||
|
||||
def embedding(self,
|
||||
model: str,
|
||||
input: list,
|
||||
|
@ -367,13 +409,6 @@ class OpenAIChatCompletion(BaseLLM):
|
|||
max_retries = data.pop("max_retries", 2)
|
||||
if not isinstance(max_retries, int):
|
||||
raise OpenAIError(status_code=422, message="max retries must be an int")
|
||||
if aembedding == True:
|
||||
response = self.aembedding(data=data, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries) # type: ignore
|
||||
return response
|
||||
if client is None:
|
||||
openai_client = OpenAI(api_key=api_key, base_url=api_base, http_client=litellm.client_session, timeout=timeout, max_retries=max_retries)
|
||||
else:
|
||||
openai_client = client
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=input,
|
||||
|
@ -381,6 +416,14 @@ class OpenAIChatCompletion(BaseLLM):
|
|||
additional_args={"complete_input_dict": data, "api_base": api_base},
|
||||
)
|
||||
|
||||
if aembedding == True:
|
||||
response = self.aembedding(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries) # type: ignore
|
||||
return response
|
||||
if client is None:
|
||||
openai_client = OpenAI(api_key=api_key, base_url=api_base, http_client=litellm.client_session, timeout=timeout, max_retries=max_retries)
|
||||
else:
|
||||
openai_client = client
|
||||
|
||||
## COMPLETION CALL
|
||||
response = openai_client.embeddings.create(**data) # type: ignore
|
||||
## LOGGING
|
||||
|
@ -472,12 +515,14 @@ class OpenAITextCompletion(BaseLLM):
|
|||
else:
|
||||
prompt = " ".join([message["content"] for message in messages]) # type: ignore
|
||||
|
||||
# don't send max retries to the api, if set
|
||||
optional_params.pop("max_retries", None)
|
||||
|
||||
data = {
|
||||
"model": model,
|
||||
"prompt": prompt,
|
||||
**optional_params
|
||||
}
|
||||
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=messages,
|
||||
|
|
|
@ -73,8 +73,27 @@ def ollama_pt(model, messages): # https://github.com/jmorganca/ollama/blob/af4cf
|
|||
final_prompt_value="### Response:",
|
||||
messages=messages
|
||||
)
|
||||
elif "llava" in model:
|
||||
prompt = ""
|
||||
images = []
|
||||
for message in messages:
|
||||
if isinstance(message["content"], str):
|
||||
prompt += message["content"]
|
||||
elif isinstance(message["content"], list):
|
||||
# see https://docs.litellm.ai/docs/providers/openai#openai-vision-models
|
||||
for element in message["content"]:
|
||||
if isinstance(element, dict):
|
||||
if element["type"] == "text":
|
||||
prompt += element["text"]
|
||||
elif element["type"] == "image_url":
|
||||
image_url = element["image_url"]["url"]
|
||||
images.append(image_url)
|
||||
return {
|
||||
"prompt": prompt,
|
||||
"images": images
|
||||
}
|
||||
else:
|
||||
prompt = "".join(m["content"] for m in messages)
|
||||
prompt = "".join(m["content"] if isinstance(m['content'], str) is str else "".join(m['content']) for m in messages)
|
||||
return prompt
|
||||
|
||||
def mistral_instruct_pt(messages):
|
||||
|
@ -161,6 +180,8 @@ def phind_codellama_pt(messages):
|
|||
|
||||
def hf_chat_template(model: str, messages: list, chat_template: Optional[Any]=None):
|
||||
## get the tokenizer config from huggingface
|
||||
bos_token = ""
|
||||
eos_token = ""
|
||||
if chat_template is None:
|
||||
def _get_tokenizer_config(hf_model_name):
|
||||
url = f"https://huggingface.co/{hf_model_name}/raw/main/tokenizer_config.json"
|
||||
|
@ -187,7 +208,10 @@ def hf_chat_template(model: str, messages: list, chat_template: Optional[Any]=No
|
|||
# Create a template object from the template text
|
||||
env = Environment()
|
||||
env.globals['raise_exception'] = raise_exception
|
||||
template = env.from_string(chat_template)
|
||||
try:
|
||||
template = env.from_string(chat_template)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
def _is_system_in_template():
|
||||
try:
|
||||
|
@ -227,8 +251,8 @@ def hf_chat_template(model: str, messages: list, chat_template: Optional[Any]=No
|
|||
new_messages.append(reformatted_messages[-1])
|
||||
rendered_text = template.render(bos_token=bos_token, eos_token=eos_token, messages=new_messages)
|
||||
return rendered_text
|
||||
except:
|
||||
raise Exception("Error rendering template")
|
||||
except Exception as e:
|
||||
raise Exception(f"Error rendering template - {str(e)}")
|
||||
|
||||
# Anthropic template
|
||||
def claude_2_1_pt(messages: list): # format - https://docs.anthropic.com/claude/docs/how-to-use-system-prompts
|
||||
|
@ -266,20 +290,26 @@ def claude_2_1_pt(messages: list): # format - https://docs.anthropic.com/claude/
|
|||
### TOGETHER AI
|
||||
|
||||
def get_model_info(token, model):
|
||||
headers = {
|
||||
'Authorization': f'Bearer {token}'
|
||||
}
|
||||
response = requests.get('https://api.together.xyz/models/info', headers=headers)
|
||||
if response.status_code == 200:
|
||||
model_info = response.json()
|
||||
for m in model_info:
|
||||
if m["name"].lower().strip() == model.strip():
|
||||
return m['config'].get('prompt_format', None), m['config'].get('chat_template', None)
|
||||
return None, None
|
||||
else:
|
||||
try:
|
||||
headers = {
|
||||
'Authorization': f'Bearer {token}'
|
||||
}
|
||||
response = requests.get('https://api.together.xyz/models/info', headers=headers)
|
||||
if response.status_code == 200:
|
||||
model_info = response.json()
|
||||
for m in model_info:
|
||||
if m["name"].lower().strip() == model.strip():
|
||||
return m['config'].get('prompt_format', None), m['config'].get('chat_template', None)
|
||||
return None, None
|
||||
else:
|
||||
return None, None
|
||||
except Exception as e: # safely fail a prompt template request
|
||||
return None, None
|
||||
|
||||
def format_prompt_togetherai(messages, prompt_format, chat_template):
|
||||
if prompt_format is None:
|
||||
return default_pt(messages)
|
||||
|
||||
human_prompt, assistant_prompt = prompt_format.split('{prompt}')
|
||||
|
||||
if chat_template is not None:
|
||||
|
@ -397,4 +427,4 @@ def prompt_factory(model: str, messages: list, custom_llm_provider: Optional[str
|
|||
return hf_chat_template(original_model_name, messages)
|
||||
except:
|
||||
return default_pt(messages=messages) # default that covers Bloom, T-5, any non-chat tuned model (e.g. base Llama2)
|
||||
|
||||
|
||||
|
|
|
@ -232,7 +232,8 @@ def completion(
|
|||
if system_prompt is not None:
|
||||
input_data = {
|
||||
"prompt": prompt,
|
||||
"system_prompt": system_prompt
|
||||
"system_prompt": system_prompt,
|
||||
**optional_params
|
||||
}
|
||||
# Otherwise, use the prompt as is
|
||||
else:
|
||||
|
|
|
@ -158,6 +158,7 @@ def completion(
|
|||
)
|
||||
except Exception as e:
|
||||
raise SagemakerError(status_code=500, message=f"{str(e)}")
|
||||
|
||||
response = response["Body"].read().decode("utf8")
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
|
@ -171,10 +172,17 @@ def completion(
|
|||
completion_response = json.loads(response)
|
||||
try:
|
||||
completion_response_choices = completion_response[0]
|
||||
completion_output = ""
|
||||
if "generation" in completion_response_choices:
|
||||
model_response["choices"][0]["message"]["content"] = completion_response_choices["generation"]
|
||||
completion_output += completion_response_choices["generation"]
|
||||
elif "generated_text" in completion_response_choices:
|
||||
model_response["choices"][0]["message"]["content"] = completion_response_choices["generated_text"]
|
||||
completion_output += completion_response_choices["generated_text"]
|
||||
|
||||
# check if the prompt template is part of output, if so - filter it out
|
||||
if completion_output.startswith(prompt) and "<s>" in prompt:
|
||||
completion_output = completion_output.replace(prompt, "", 1)
|
||||
|
||||
model_response["choices"][0]["message"]["content"] = completion_output
|
||||
except:
|
||||
raise SagemakerError(message=f"LiteLLM Error: Unable to parse sagemaker RAW RESPONSE {json.dumps(completion_response)}", status_code=500)
|
||||
|
||||
|
|
|
@ -173,10 +173,11 @@ def completion(
|
|||
message=json.dumps(completion_response["output"]), status_code=response.status_code
|
||||
)
|
||||
|
||||
if len(completion_response["output"]["choices"][0]["text"]) > 0:
|
||||
if len(completion_response["output"]["choices"][0]["text"]) >= 0:
|
||||
model_response["choices"][0]["message"]["content"] = completion_response["output"]["choices"][0]["text"]
|
||||
|
||||
## CALCULATING USAGE
|
||||
print_verbose(f"CALCULATING TOGETHERAI TOKEN USAGE. Model Response: {model_response}; model_response['choices'][0]['message'].get('content', ''): {model_response['choices'][0]['message'].get('content', None)}")
|
||||
prompt_tokens = len(encoding.encode(prompt))
|
||||
completion_tokens = len(
|
||||
encoding.encode(model_response["choices"][0]["message"].get("content", ""))
|
||||
|
|
|
@ -4,7 +4,7 @@ from enum import Enum
|
|||
import requests
|
||||
import time
|
||||
from typing import Callable, Optional
|
||||
from litellm.utils import ModelResponse, Usage
|
||||
from litellm.utils import ModelResponse, Usage, CustomStreamWrapper
|
||||
import litellm
|
||||
import httpx
|
||||
|
||||
|
@ -57,6 +57,108 @@ class VertexAIConfig():
|
|||
and not isinstance(v, (types.FunctionType, types.BuiltinFunctionType, classmethod, staticmethod))
|
||||
and v is not None}
|
||||
|
||||
def _get_image_bytes_from_url(image_url: str) -> bytes:
|
||||
try:
|
||||
response = requests.get(image_url)
|
||||
response.raise_for_status() # Raise an error for bad responses (4xx and 5xx)
|
||||
image_bytes = response.content
|
||||
return image_bytes
|
||||
except requests.exceptions.RequestException as e:
|
||||
# Handle any request exceptions (e.g., connection error, timeout)
|
||||
return b'' # Return an empty bytes object or handle the error as needed
|
||||
|
||||
|
||||
def _load_image_from_url(image_url: str):
|
||||
"""
|
||||
Loads an image from a URL.
|
||||
|
||||
Args:
|
||||
image_url (str): The URL of the image.
|
||||
|
||||
Returns:
|
||||
Image: The loaded image.
|
||||
"""
|
||||
from vertexai.preview.generative_models import GenerativeModel, Part, GenerationConfig, Image
|
||||
image_bytes = _get_image_bytes_from_url(image_url)
|
||||
return Image.from_bytes(image_bytes)
|
||||
|
||||
def _gemini_vision_convert_messages(
|
||||
messages: list
|
||||
):
|
||||
"""
|
||||
Converts given messages for GPT-4 Vision to Gemini format.
|
||||
|
||||
Args:
|
||||
messages (list): The messages to convert. Each message can be a dictionary with a "content" key. The content can be a string or a list of elements. If it is a string, it will be concatenated to the prompt. If it is a list, each element will be processed based on its type:
|
||||
- If the element is a dictionary with a "type" key equal to "text", its "text" value will be concatenated to the prompt.
|
||||
- If the element is a dictionary with a "type" key equal to "image_url", its "image_url" value will be added to the list of images.
|
||||
|
||||
Returns:
|
||||
tuple: A tuple containing the prompt (a string) and the processed images (a list of objects representing the images).
|
||||
|
||||
Raises:
|
||||
VertexAIError: If the import of the 'vertexai' module fails, indicating that 'google-cloud-aiplatform' needs to be installed.
|
||||
Exception: If any other exception occurs during the execution of the function.
|
||||
|
||||
Note:
|
||||
This function is based on the code from the 'gemini/getting-started/intro_gemini_python.ipynb' notebook in the 'generative-ai' repository on GitHub.
|
||||
The supported MIME types for images include 'image/png' and 'image/jpeg'.
|
||||
|
||||
Examples:
|
||||
>>> messages = [
|
||||
... {"content": "Hello, world!"},
|
||||
... {"content": [{"type": "text", "text": "This is a text message."}, {"type": "image_url", "image_url": "example.com/image.png"}]},
|
||||
... ]
|
||||
>>> _gemini_vision_convert_messages(messages)
|
||||
('Hello, world!This is a text message.', [<Part object>, <Part object>])
|
||||
"""
|
||||
try:
|
||||
import vertexai
|
||||
except:
|
||||
raise VertexAIError(status_code=400,message="vertexai import failed please run `pip install google-cloud-aiplatform`")
|
||||
try:
|
||||
from vertexai.preview.language_models import ChatModel, CodeChatModel, InputOutputTextPair
|
||||
from vertexai.language_models import TextGenerationModel, CodeGenerationModel
|
||||
from vertexai.preview.generative_models import GenerativeModel, Part, GenerationConfig, Image
|
||||
|
||||
# given messages for gpt-4 vision, convert them for gemini
|
||||
# https://github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/getting-started/intro_gemini_python.ipynb
|
||||
prompt = ""
|
||||
images = []
|
||||
for message in messages:
|
||||
if isinstance(message["content"], str):
|
||||
prompt += message["content"]
|
||||
elif isinstance(message["content"], list):
|
||||
# see https://docs.litellm.ai/docs/providers/openai#openai-vision-models
|
||||
for element in message["content"]:
|
||||
if isinstance(element, dict):
|
||||
if element["type"] == "text":
|
||||
prompt += element["text"]
|
||||
elif element["type"] == "image_url":
|
||||
image_url = element["image_url"]["url"]
|
||||
images.append(image_url)
|
||||
# processing images passed to gemini
|
||||
processed_images = []
|
||||
for img in images:
|
||||
if "gs://" in img:
|
||||
# Case 1: Images with Cloud Storage URIs
|
||||
# The supported MIME types for images include image/png and image/jpeg.
|
||||
part_mime = "image/png" if "png" in img else "image/jpeg"
|
||||
google_clooud_part = Part.from_uri(img, mime_type=part_mime)
|
||||
processed_images.append(google_clooud_part)
|
||||
elif "https:/" in img:
|
||||
# Case 2: Images with direct links
|
||||
image = _load_image_from_url(img)
|
||||
processed_images.append(image)
|
||||
elif ".mp4" in img and "gs://" in img:
|
||||
# Case 3: Videos with Cloud Storage URIs
|
||||
part_mime = "video/mp4"
|
||||
google_clooud_part = Part.from_uri(img, mime_type=part_mime)
|
||||
processed_images.append(google_clooud_part)
|
||||
return prompt, processed_images
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
def completion(
|
||||
model: str,
|
||||
messages: list,
|
||||
|
@ -69,6 +171,7 @@ def completion(
|
|||
optional_params=None,
|
||||
litellm_params=None,
|
||||
logger_fn=None,
|
||||
acompletion: bool=False
|
||||
):
|
||||
try:
|
||||
import vertexai
|
||||
|
@ -77,6 +180,8 @@ def completion(
|
|||
try:
|
||||
from vertexai.preview.language_models import ChatModel, CodeChatModel, InputOutputTextPair
|
||||
from vertexai.language_models import TextGenerationModel, CodeGenerationModel
|
||||
from vertexai.preview.generative_models import GenerativeModel, Part, GenerationConfig
|
||||
|
||||
|
||||
vertexai.init(
|
||||
project=vertex_project, location=vertex_location
|
||||
|
@ -90,34 +195,94 @@ def completion(
|
|||
|
||||
# vertexai does not use an API key, it looks for credentials.json in the environment
|
||||
|
||||
prompt = " ".join([message["content"] for message in messages])
|
||||
prompt = " ".join([message["content"] for message in messages if isinstance(message["content"], str)])
|
||||
|
||||
mode = ""
|
||||
|
||||
request_str = ""
|
||||
if model in litellm.vertex_chat_models:
|
||||
chat_model = ChatModel.from_pretrained(model)
|
||||
response_obj = None
|
||||
if model in litellm.vertex_language_models:
|
||||
llm_model = GenerativeModel(model)
|
||||
mode = ""
|
||||
request_str += f"llm_model = GenerativeModel({model})\n"
|
||||
elif model in litellm.vertex_vision_models:
|
||||
llm_model = GenerativeModel(model)
|
||||
request_str += f"llm_model = GenerativeModel({model})\n"
|
||||
mode = "vision"
|
||||
elif model in litellm.vertex_chat_models:
|
||||
llm_model = ChatModel.from_pretrained(model)
|
||||
mode = "chat"
|
||||
request_str += f"chat_model = ChatModel.from_pretrained({model})\n"
|
||||
request_str += f"llm_model = ChatModel.from_pretrained({model})\n"
|
||||
elif model in litellm.vertex_text_models:
|
||||
text_model = TextGenerationModel.from_pretrained(model)
|
||||
llm_model = TextGenerationModel.from_pretrained(model)
|
||||
mode = "text"
|
||||
request_str += f"text_model = TextGenerationModel.from_pretrained({model})\n"
|
||||
request_str += f"llm_model = TextGenerationModel.from_pretrained({model})\n"
|
||||
elif model in litellm.vertex_code_text_models:
|
||||
text_model = CodeGenerationModel.from_pretrained(model)
|
||||
llm_model = CodeGenerationModel.from_pretrained(model)
|
||||
mode = "text"
|
||||
request_str += f"text_model = CodeGenerationModel.from_pretrained({model})\n"
|
||||
else: # vertex_code_chat_models
|
||||
chat_model = CodeChatModel.from_pretrained(model)
|
||||
request_str += f"llm_model = CodeGenerationModel.from_pretrained({model})\n"
|
||||
else: # vertex_code_llm_models
|
||||
llm_model = CodeChatModel.from_pretrained(model)
|
||||
mode = "chat"
|
||||
request_str += f"chat_model = CodeChatModel.from_pretrained({model})\n"
|
||||
request_str += f"llm_model = CodeChatModel.from_pretrained({model})\n"
|
||||
|
||||
if mode == "chat":
|
||||
chat = chat_model.start_chat()
|
||||
request_str+= f"chat = chat_model.start_chat()\n"
|
||||
if acompletion == True: # [TODO] expand support to vertex ai chat + text models
|
||||
if optional_params.get("stream", False) is True:
|
||||
# async streaming
|
||||
return async_streaming(llm_model=llm_model, mode=mode, prompt=prompt, logging_obj=logging_obj, request_str=request_str, model=model, model_response=model_response, messages=messages, print_verbose=print_verbose, **optional_params)
|
||||
return async_completion(llm_model=llm_model, mode=mode, prompt=prompt, logging_obj=logging_obj, request_str=request_str, model=model, model_response=model_response, encoding=encoding, messages=messages,print_verbose=print_verbose,**optional_params)
|
||||
|
||||
if mode == "":
|
||||
chat = llm_model.start_chat()
|
||||
request_str+= f"chat = llm_model.start_chat()\n"
|
||||
|
||||
if "stream" in optional_params and optional_params["stream"] == True:
|
||||
stream = optional_params.pop("stream")
|
||||
request_str += f"chat.send_message({prompt}, generation_config=GenerationConfig(**{optional_params}), stream={stream})\n"
|
||||
## LOGGING
|
||||
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
|
||||
model_response = chat.send_message(prompt, generation_config=GenerationConfig(**optional_params), stream=stream)
|
||||
optional_params["stream"] = True
|
||||
return model_response
|
||||
request_str += f"chat.send_message({prompt}, generation_config=GenerationConfig(**{optional_params})).text\n"
|
||||
## LOGGING
|
||||
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
|
||||
response_obj = chat.send_message(prompt, generation_config=GenerationConfig(**optional_params))
|
||||
completion_response = response_obj.text
|
||||
response_obj = response_obj._raw_response
|
||||
elif mode == "vision":
|
||||
print_verbose("\nMaking VertexAI Gemini Pro Vision Call")
|
||||
print_verbose(f"\nProcessing input messages = {messages}")
|
||||
|
||||
prompt, images = _gemini_vision_convert_messages(messages=messages)
|
||||
content = [prompt] + images
|
||||
if "stream" in optional_params and optional_params["stream"] == True:
|
||||
stream = optional_params.pop("stream")
|
||||
request_str += f"response = llm_model.generate_content({content}, generation_config=GenerationConfig(**{optional_params}), stream={stream})\n"
|
||||
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
|
||||
|
||||
model_response = llm_model.generate_content(
|
||||
contents=content,
|
||||
generation_config=GenerationConfig(**optional_params),
|
||||
stream=True
|
||||
)
|
||||
optional_params["stream"] = True
|
||||
return model_response
|
||||
|
||||
request_str += f"response = llm_model.generate_content({content})\n"
|
||||
## LOGGING
|
||||
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
|
||||
|
||||
## LLM Call
|
||||
response = llm_model.generate_content(
|
||||
contents=content,
|
||||
generation_config=GenerationConfig(**optional_params)
|
||||
)
|
||||
completion_response = response.text
|
||||
response_obj = response._raw_response
|
||||
elif mode == "chat":
|
||||
chat = llm_model.start_chat()
|
||||
request_str+= f"chat = llm_model.start_chat()\n"
|
||||
|
||||
if "stream" in optional_params and optional_params["stream"] == True:
|
||||
# NOTE: VertexAI does not accept stream=True as a param and raises an error,
|
||||
|
@ -125,27 +290,30 @@ def completion(
|
|||
# after we get the response we add optional_params["stream"] = True, since main.py needs to know it's a streaming response to then transform it for the OpenAI format
|
||||
optional_params.pop("stream", None) # vertex ai raises an error when passing stream in optional params
|
||||
request_str += f"chat.send_message_streaming({prompt}, **{optional_params})\n"
|
||||
## LOGGING
|
||||
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
|
||||
model_response = chat.send_message_streaming(prompt, **optional_params)
|
||||
optional_params["stream"] = True
|
||||
return model_response
|
||||
|
||||
request_str += f"chat.send_message({prompt}, **{optional_params}).text\n"
|
||||
## LOGGING
|
||||
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
|
||||
completion_response = chat.send_message(prompt, **optional_params).text
|
||||
elif mode == "text":
|
||||
|
||||
if "stream" in optional_params and optional_params["stream"] == True:
|
||||
optional_params.pop("stream", None) # See note above on handling streaming for vertex ai
|
||||
request_str += f"text_model.predict_streaming({prompt}, **{optional_params})\n"
|
||||
request_str += f"llm_model.predict_streaming({prompt}, **{optional_params})\n"
|
||||
## LOGGING
|
||||
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
|
||||
model_response = text_model.predict_streaming(prompt, **optional_params)
|
||||
model_response = llm_model.predict_streaming(prompt, **optional_params)
|
||||
optional_params["stream"] = True
|
||||
return model_response
|
||||
|
||||
request_str += f"text_model.predict({prompt}, **{optional_params}).text\n"
|
||||
request_str += f"llm_model.predict({prompt}, **{optional_params}).text\n"
|
||||
## LOGGING
|
||||
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
|
||||
completion_response = text_model.predict(prompt, **optional_params).text
|
||||
completion_response = llm_model.predict(prompt, **optional_params).text
|
||||
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
|
@ -161,22 +329,162 @@ def completion(
|
|||
model_response["created"] = int(time.time())
|
||||
model_response["model"] = model
|
||||
## CALCULATING USAGE
|
||||
prompt_tokens = len(
|
||||
encoding.encode(prompt)
|
||||
)
|
||||
completion_tokens = len(
|
||||
encoding.encode(model_response["choices"][0]["message"].get("content", ""))
|
||||
)
|
||||
usage = Usage(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=prompt_tokens + completion_tokens
|
||||
if model in litellm.vertex_language_models and response_obj is not None:
|
||||
model_response["choices"][0].finish_reason = response_obj.candidates[0].finish_reason.name
|
||||
usage = Usage(prompt_tokens=response_obj.usage_metadata.prompt_token_count,
|
||||
completion_tokens=response_obj.usage_metadata.candidates_token_count,
|
||||
total_tokens=response_obj.usage_metadata.total_token_count)
|
||||
else:
|
||||
prompt_tokens = len(
|
||||
encoding.encode(prompt)
|
||||
)
|
||||
completion_tokens = len(
|
||||
encoding.encode(model_response["choices"][0]["message"].get("content", ""))
|
||||
)
|
||||
usage = Usage(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=prompt_tokens + completion_tokens
|
||||
)
|
||||
model_response.usage = usage
|
||||
return model_response
|
||||
except Exception as e:
|
||||
raise VertexAIError(status_code=500, message=str(e))
|
||||
|
||||
async def async_completion(llm_model, mode: str, prompt: str, model: str, model_response: ModelResponse, logging_obj=None, request_str=None, encoding=None, messages = None, print_verbose = None, **optional_params):
|
||||
"""
|
||||
Add support for acompletion calls for gemini-pro
|
||||
"""
|
||||
try:
|
||||
from vertexai.preview.generative_models import GenerationConfig
|
||||
|
||||
if mode == "":
|
||||
# gemini-pro
|
||||
chat = llm_model.start_chat()
|
||||
## LOGGING
|
||||
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
|
||||
response_obj = await chat.send_message_async(prompt, generation_config=GenerationConfig(**optional_params))
|
||||
completion_response = response_obj.text
|
||||
response_obj = response_obj._raw_response
|
||||
elif mode == "vision":
|
||||
print_verbose("\nMaking VertexAI Gemini Pro Vision Call")
|
||||
print_verbose(f"\nProcessing input messages = {messages}")
|
||||
|
||||
prompt, images = _gemini_vision_convert_messages(messages=messages)
|
||||
content = [prompt] + images
|
||||
|
||||
request_str += f"response = llm_model.generate_content({content})\n"
|
||||
## LOGGING
|
||||
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
|
||||
|
||||
## LLM Call
|
||||
response = await llm_model._generate_content_async(
|
||||
contents=content,
|
||||
generation_config=GenerationConfig(**optional_params)
|
||||
)
|
||||
completion_response = response.text
|
||||
response_obj = response._raw_response
|
||||
elif mode == "chat":
|
||||
# chat-bison etc.
|
||||
chat = llm_model.start_chat()
|
||||
## LOGGING
|
||||
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
|
||||
response_obj = await chat.send_message_async(prompt, **optional_params)
|
||||
completion_response = response_obj.text
|
||||
elif mode == "text":
|
||||
# gecko etc.
|
||||
request_str += f"llm_model.predict({prompt}, **{optional_params}).text\n"
|
||||
## LOGGING
|
||||
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
|
||||
response_obj = await llm_model.predict_async(prompt, **optional_params)
|
||||
completion_response = response_obj.text
|
||||
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
input=prompt, api_key=None, original_response=completion_response
|
||||
)
|
||||
|
||||
## RESPONSE OBJECT
|
||||
if len(str(completion_response)) > 0:
|
||||
model_response["choices"][0]["message"][
|
||||
"content"
|
||||
] = str(completion_response)
|
||||
model_response["choices"][0]["message"]["content"] = str(completion_response)
|
||||
model_response["created"] = int(time.time())
|
||||
model_response["model"] = model
|
||||
## CALCULATING USAGE
|
||||
if model in litellm.vertex_language_models and response_obj is not None:
|
||||
model_response["choices"][0].finish_reason = response_obj.candidates[0].finish_reason.name
|
||||
usage = Usage(prompt_tokens=response_obj.usage_metadata.prompt_token_count,
|
||||
completion_tokens=response_obj.usage_metadata.candidates_token_count,
|
||||
total_tokens=response_obj.usage_metadata.total_token_count)
|
||||
else:
|
||||
prompt_tokens = len(
|
||||
encoding.encode(prompt)
|
||||
)
|
||||
completion_tokens = len(
|
||||
encoding.encode(model_response["choices"][0]["message"].get("content", ""))
|
||||
)
|
||||
usage = Usage(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=prompt_tokens + completion_tokens
|
||||
)
|
||||
model_response.usage = usage
|
||||
return model_response
|
||||
except Exception as e:
|
||||
raise VertexAIError(status_code=500, message=str(e))
|
||||
|
||||
async def async_streaming(llm_model, mode: str, prompt: str, model: str, model_response: ModelResponse, logging_obj=None, request_str=None, messages = None, print_verbose = None, **optional_params):
|
||||
"""
|
||||
Add support for async streaming calls for gemini-pro
|
||||
"""
|
||||
from vertexai.preview.generative_models import GenerationConfig
|
||||
if mode == "":
|
||||
# gemini-pro
|
||||
chat = llm_model.start_chat()
|
||||
stream = optional_params.pop("stream")
|
||||
request_str += f"chat.send_message_async({prompt},generation_config=GenerationConfig(**{optional_params}), stream={stream})\n"
|
||||
## LOGGING
|
||||
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
|
||||
response = await chat.send_message_async(prompt, generation_config=GenerationConfig(**optional_params), stream=stream)
|
||||
optional_params["stream"] = True
|
||||
elif mode == "vision":
|
||||
stream = optional_params.pop("stream")
|
||||
|
||||
print_verbose("\nMaking VertexAI Gemini Pro Vision Call")
|
||||
print_verbose(f"\nProcessing input messages = {messages}")
|
||||
|
||||
prompt, images = _gemini_vision_convert_messages(messages=messages)
|
||||
content = [prompt] + images
|
||||
stream = optional_params.pop("stream")
|
||||
request_str += f"response = llm_model.generate_content({content}, generation_config=GenerationConfig(**{optional_params}), stream={stream})\n"
|
||||
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
|
||||
|
||||
response = llm_model._generate_content_streaming_async(
|
||||
contents=content,
|
||||
generation_config=GenerationConfig(**optional_params),
|
||||
stream=True
|
||||
)
|
||||
optional_params["stream"] = True
|
||||
elif mode == "chat":
|
||||
chat = llm_model.start_chat()
|
||||
optional_params.pop("stream", None) # vertex ai raises an error when passing stream in optional params
|
||||
request_str += f"chat.send_message_streaming_async({prompt}, **{optional_params})\n"
|
||||
## LOGGING
|
||||
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
|
||||
response = chat.send_message_streaming_async(prompt, **optional_params)
|
||||
optional_params["stream"] = True
|
||||
elif mode == "text":
|
||||
optional_params.pop("stream", None) # See note above on handling streaming for vertex ai
|
||||
request_str += f"llm_model.predict_streaming_async({prompt}, **{optional_params})\n"
|
||||
## LOGGING
|
||||
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
|
||||
response = llm_model.predict_streaming_async(prompt, **optional_params)
|
||||
|
||||
streamwrapper = CustomStreamWrapper(completion_stream=response, model=model, custom_llm_provider="vertex_ai",logging_obj=logging_obj)
|
||||
async for transformed_chunk in streamwrapper:
|
||||
yield transformed_chunk
|
||||
|
||||
def embedding():
|
||||
# logic for parsing in - calling - parsing out model embedding calls
|
||||
|
|
251
litellm/main.py
251
litellm/main.py
|
@ -14,6 +14,7 @@ import dotenv, traceback, random, asyncio, time, contextvars
|
|||
from copy import deepcopy
|
||||
import httpx
|
||||
import litellm
|
||||
|
||||
from litellm import ( # type: ignore
|
||||
client,
|
||||
exception_type,
|
||||
|
@ -31,7 +32,8 @@ from litellm.utils import (
|
|||
mock_completion_streaming_obj,
|
||||
convert_to_model_response_object,
|
||||
token_counter,
|
||||
Usage
|
||||
Usage,
|
||||
get_optional_params_embeddings
|
||||
)
|
||||
from .llms import (
|
||||
anthropic,
|
||||
|
@ -171,11 +173,14 @@ async def acompletion(*args, **kwargs):
|
|||
or custom_llm_provider == "azure"
|
||||
or custom_llm_provider == "custom_openai"
|
||||
or custom_llm_provider == "anyscale"
|
||||
or custom_llm_provider == "mistral"
|
||||
or custom_llm_provider == "openrouter"
|
||||
or custom_llm_provider == "deepinfra"
|
||||
or custom_llm_provider == "perplexity"
|
||||
or custom_llm_provider == "text-completion-openai"
|
||||
or custom_llm_provider == "huggingface"): # currently implemented aiohttp calls for just azure and openai, soon all.
|
||||
or custom_llm_provider == "huggingface"
|
||||
or custom_llm_provider == "ollama"
|
||||
or custom_llm_provider == "vertex_ai"): # currently implemented aiohttp calls for just azure and openai, soon all.
|
||||
if kwargs.get("stream", False):
|
||||
response = completion(*args, **kwargs)
|
||||
else:
|
||||
|
@ -200,9 +205,12 @@ async def acompletion(*args, **kwargs):
|
|||
|
||||
async def _async_streaming(response, model, custom_llm_provider, args):
|
||||
try:
|
||||
print_verbose(f"received response in _async_streaming: {response}")
|
||||
async for line in response:
|
||||
print_verbose(f"line in async streaming: {line}")
|
||||
yield line
|
||||
except Exception as e:
|
||||
print_verbose(f"error raised _async_streaming: {traceback.format_exc()}")
|
||||
raise exception_type(
|
||||
model=model, custom_llm_provider=custom_llm_provider, original_exception=e, completion_kwargs=args,
|
||||
)
|
||||
|
@ -278,7 +286,7 @@ def completion(
|
|||
|
||||
# Optional liteLLM function params
|
||||
**kwargs,
|
||||
) -> ModelResponse:
|
||||
) -> Union[ModelResponse, CustomStreamWrapper]:
|
||||
"""
|
||||
Perform a completion() using any of litellm supported llms (example gpt-4, gpt-3.5-turbo, claude-2, command-nightly)
|
||||
Parameters:
|
||||
|
@ -319,7 +327,6 @@ def completion(
|
|||
######### unpacking kwargs #####################
|
||||
args = locals()
|
||||
api_base = kwargs.get('api_base', None)
|
||||
return_async = kwargs.get('return_async', False)
|
||||
mock_response = kwargs.get('mock_response', None)
|
||||
force_timeout= kwargs.get('force_timeout', 600) ## deprecated
|
||||
logger_fn = kwargs.get('logger_fn', None)
|
||||
|
@ -344,13 +351,14 @@ def completion(
|
|||
final_prompt_value = kwargs.get("final_prompt_value", None)
|
||||
bos_token = kwargs.get("bos_token", None)
|
||||
eos_token = kwargs.get("eos_token", None)
|
||||
preset_cache_key = kwargs.get("preset_cache_key", None)
|
||||
hf_model_name = kwargs.get("hf_model_name", None)
|
||||
### ASYNC CALLS ###
|
||||
acompletion = kwargs.get("acompletion", False)
|
||||
client = kwargs.get("client", None)
|
||||
######## end of unpacking kwargs ###########
|
||||
openai_params = ["functions", "function_call", "temperature", "temperature", "top_p", "n", "stream", "stop", "max_tokens", "presence_penalty", "frequency_penalty", "logit_bias", "user", "request_timeout", "api_base", "api_version", "api_key", "deployment_id", "organization", "base_url", "default_headers", "timeout", "response_format", "seed", "tools", "tool_choice", "max_retries"]
|
||||
litellm_params = ["metadata", "acompletion", "caching", "return_async", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response", "self", "client", "rpm", "tpm", "input_cost_per_token", "output_cost_per_token", "hf_model_name", "model_info", "proxy_server_request"]
|
||||
litellm_params = ["metadata", "acompletion", "caching", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response", "self", "client", "rpm", "tpm", "input_cost_per_token", "output_cost_per_token", "hf_model_name", "model_info", "proxy_server_request", "preset_cache_key", "caching_groups"]
|
||||
default_params = openai_params + litellm_params
|
||||
non_default_params = {k: v for k,v in kwargs.items() if k not in default_params} # model-specific params - pass them straight to the model/provider
|
||||
if mock_response:
|
||||
|
@ -384,7 +392,6 @@ def completion(
|
|||
model=deployment_id
|
||||
custom_llm_provider="azure"
|
||||
model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base, api_key=api_key)
|
||||
|
||||
### REGISTER CUSTOM MODEL PRICING -- IF GIVEN ###
|
||||
if input_cost_per_token is not None and output_cost_per_token is not None:
|
||||
litellm.register_model({
|
||||
|
@ -448,7 +455,6 @@ def completion(
|
|||
# For logging - save the values of the litellm-specific params passed in
|
||||
litellm_params = get_litellm_params(
|
||||
acompletion=acompletion,
|
||||
return_async=return_async,
|
||||
api_key=api_key,
|
||||
force_timeout=force_timeout,
|
||||
logger_fn=logger_fn,
|
||||
|
@ -460,7 +466,8 @@ def completion(
|
|||
completion_call_id=id,
|
||||
metadata=metadata,
|
||||
model_info=model_info,
|
||||
proxy_server_request=proxy_server_request
|
||||
proxy_server_request=proxy_server_request,
|
||||
preset_cache_key=preset_cache_key
|
||||
)
|
||||
logging.update_environment_variables(model=model, user=user, optional_params=optional_params, litellm_params=litellm_params)
|
||||
if custom_llm_provider == "azure":
|
||||
|
@ -524,23 +531,25 @@ def completion(
|
|||
client=client # pass AsyncAzureOpenAI, AzureOpenAI client
|
||||
)
|
||||
|
||||
## LOGGING
|
||||
logging.post_call(
|
||||
input=messages,
|
||||
api_key=api_key,
|
||||
original_response=response,
|
||||
additional_args={
|
||||
"headers": headers,
|
||||
"api_version": api_version,
|
||||
"api_base": api_base,
|
||||
},
|
||||
)
|
||||
if optional_params.get("stream", False) or acompletion == True:
|
||||
## LOGGING
|
||||
logging.post_call(
|
||||
input=messages,
|
||||
api_key=api_key,
|
||||
original_response=response,
|
||||
additional_args={
|
||||
"headers": headers,
|
||||
"api_version": api_version,
|
||||
"api_base": api_base,
|
||||
},
|
||||
)
|
||||
elif (
|
||||
model in litellm.open_ai_chat_completion_models
|
||||
or custom_llm_provider == "custom_openai"
|
||||
or custom_llm_provider == "deepinfra"
|
||||
or custom_llm_provider == "perplexity"
|
||||
or custom_llm_provider == "anyscale"
|
||||
or custom_llm_provider == "mistral"
|
||||
or custom_llm_provider == "openai"
|
||||
or "ft:gpt-3.5-turbo" in model # finetune gpt-3.5-turbo
|
||||
): # allow user to make an openai call with a custom base
|
||||
|
@ -604,19 +613,19 @@ def completion(
|
|||
)
|
||||
raise e
|
||||
|
||||
## LOGGING
|
||||
logging.post_call(
|
||||
input=messages,
|
||||
api_key=api_key,
|
||||
original_response=response,
|
||||
additional_args={"headers": headers},
|
||||
)
|
||||
if optional_params.get("stream", False):
|
||||
## LOGGING
|
||||
logging.post_call(
|
||||
input=messages,
|
||||
api_key=api_key,
|
||||
original_response=response,
|
||||
additional_args={"headers": headers},
|
||||
)
|
||||
elif (
|
||||
custom_llm_provider == "text-completion-openai"
|
||||
or "ft:babbage-002" in model
|
||||
or "ft:davinci-002" in model # support for finetuned completion models
|
||||
):
|
||||
# print("calling custom openai provider")
|
||||
openai.api_type = "openai"
|
||||
|
||||
api_base = (
|
||||
|
@ -655,17 +664,6 @@ def completion(
|
|||
prompt = messages[0]["content"]
|
||||
else:
|
||||
prompt = " ".join([message["content"] for message in messages]) # type: ignore
|
||||
## LOGGING
|
||||
logging.pre_call(
|
||||
input=prompt,
|
||||
api_key=api_key,
|
||||
additional_args={
|
||||
"openai_organization": litellm.organization,
|
||||
"headers": headers,
|
||||
"api_base": api_base,
|
||||
"api_type": openai.api_type,
|
||||
},
|
||||
)
|
||||
## COMPLETION CALL
|
||||
model_response = openai_text_completions.completion(
|
||||
model=model,
|
||||
|
@ -681,9 +679,14 @@ def completion(
|
|||
logger_fn=logger_fn
|
||||
)
|
||||
|
||||
# if "stream" in optional_params and optional_params["stream"] == True:
|
||||
# response = CustomStreamWrapper(model_response, model, custom_llm_provider="text-completion-openai", logging_obj=logging)
|
||||
# return response
|
||||
if optional_params.get("stream", False) or acompletion == True:
|
||||
## LOGGING
|
||||
logging.post_call(
|
||||
input=messages,
|
||||
api_key=api_key,
|
||||
original_response=model_response,
|
||||
additional_args={"headers": headers},
|
||||
)
|
||||
response = model_response
|
||||
elif (
|
||||
"replicate" in model or
|
||||
|
@ -728,8 +731,16 @@ def completion(
|
|||
)
|
||||
if "stream" in optional_params and optional_params["stream"] == True:
|
||||
# don't try to access stream object,
|
||||
response = CustomStreamWrapper(model_response, model, logging_obj=logging, custom_llm_provider="replicate")
|
||||
return response
|
||||
model_response = CustomStreamWrapper(model_response, model, logging_obj=logging, custom_llm_provider="replicate") # type: ignore
|
||||
|
||||
if optional_params.get("stream", False) or acompletion == True:
|
||||
## LOGGING
|
||||
logging.post_call(
|
||||
input=messages,
|
||||
api_key=replicate_key,
|
||||
original_response=model_response,
|
||||
)
|
||||
|
||||
response = model_response
|
||||
|
||||
elif custom_llm_provider=="anthropic":
|
||||
|
@ -749,7 +760,7 @@ def completion(
|
|||
custom_prompt_dict
|
||||
or litellm.custom_prompt_dict
|
||||
)
|
||||
model_response = anthropic.completion(
|
||||
response = anthropic.completion(
|
||||
model=model,
|
||||
messages=messages,
|
||||
api_base=api_base,
|
||||
|
@ -765,9 +776,16 @@ def completion(
|
|||
)
|
||||
if "stream" in optional_params and optional_params["stream"] == True:
|
||||
# don't try to access stream object,
|
||||
response = CustomStreamWrapper(model_response, model, custom_llm_provider="anthropic", logging_obj=logging)
|
||||
return response
|
||||
response = model_response
|
||||
response = CustomStreamWrapper(response, model, custom_llm_provider="anthropic", logging_obj=logging)
|
||||
|
||||
if optional_params.get("stream", False) or acompletion == True:
|
||||
## LOGGING
|
||||
logging.post_call(
|
||||
input=messages,
|
||||
api_key=api_key,
|
||||
original_response=response,
|
||||
)
|
||||
response = response
|
||||
elif custom_llm_provider == "nlp_cloud":
|
||||
nlp_cloud_key = (
|
||||
api_key or litellm.nlp_cloud_key or get_secret("NLP_CLOUD_API_KEY") or litellm.api_key
|
||||
|
@ -780,7 +798,7 @@ def completion(
|
|||
or "https://api.nlpcloud.io/v1/gpu/"
|
||||
)
|
||||
|
||||
model_response = nlp_cloud.completion(
|
||||
response = nlp_cloud.completion(
|
||||
model=model,
|
||||
messages=messages,
|
||||
api_base=api_base,
|
||||
|
@ -796,9 +814,17 @@ def completion(
|
|||
|
||||
if "stream" in optional_params and optional_params["stream"] == True:
|
||||
# don't try to access stream object,
|
||||
response = CustomStreamWrapper(model_response, model, custom_llm_provider="nlp_cloud", logging_obj=logging)
|
||||
return response
|
||||
response = model_response
|
||||
response = CustomStreamWrapper(response, model, custom_llm_provider="nlp_cloud", logging_obj=logging)
|
||||
|
||||
if optional_params.get("stream", False) or acompletion == True:
|
||||
## LOGGING
|
||||
logging.post_call(
|
||||
input=messages,
|
||||
api_key=api_key,
|
||||
original_response=response,
|
||||
)
|
||||
|
||||
response = response
|
||||
elif custom_llm_provider == "aleph_alpha":
|
||||
aleph_alpha_key = (
|
||||
api_key or litellm.aleph_alpha_key or get_secret("ALEPH_ALPHA_API_KEY") or get_secret("ALEPHALPHA_API_KEY") or litellm.api_key
|
||||
|
@ -1100,7 +1126,7 @@ def completion(
|
|||
)
|
||||
return response
|
||||
response = model_response
|
||||
elif model in litellm.vertex_chat_models or model in litellm.vertex_code_chat_models or model in litellm.vertex_text_models or model in litellm.vertex_code_text_models:
|
||||
elif custom_llm_provider == "vertex_ai":
|
||||
vertex_ai_project = (litellm.vertex_project
|
||||
or get_secret("VERTEXAI_PROJECT"))
|
||||
vertex_ai_location = (litellm.vertex_location
|
||||
|
@ -1117,10 +1143,11 @@ def completion(
|
|||
encoding=encoding,
|
||||
vertex_location=vertex_ai_location,
|
||||
vertex_project=vertex_ai_project,
|
||||
logging_obj=logging
|
||||
logging_obj=logging,
|
||||
acompletion=acompletion
|
||||
)
|
||||
|
||||
if "stream" in optional_params and optional_params["stream"] == True:
|
||||
if "stream" in optional_params and optional_params["stream"] == True and acompletion == False:
|
||||
response = CustomStreamWrapper(
|
||||
model_response, model, custom_llm_provider="vertex_ai", logging_obj=logging
|
||||
)
|
||||
|
@ -1186,6 +1213,7 @@ def completion(
|
|||
# "SageMaker is currently not supporting streaming responses."
|
||||
|
||||
# fake streaming for sagemaker
|
||||
print_verbose(f"ENTERS SAGEMAKER CUSTOMSTREAMWRAPPER")
|
||||
resp_string = model_response["choices"][0]["message"]["content"]
|
||||
response = CustomStreamWrapper(
|
||||
resp_string, model, custom_llm_provider="sagemaker", logging_obj=logging
|
||||
|
@ -1200,7 +1228,7 @@ def completion(
|
|||
custom_prompt_dict
|
||||
or litellm.custom_prompt_dict
|
||||
)
|
||||
model_response = bedrock.completion(
|
||||
response = bedrock.completion(
|
||||
model=model,
|
||||
messages=messages,
|
||||
custom_prompt_dict=litellm.custom_prompt_dict,
|
||||
|
@ -1218,16 +1246,24 @@ def completion(
|
|||
# don't try to access stream object,
|
||||
if "ai21" in model:
|
||||
response = CustomStreamWrapper(
|
||||
model_response, model, custom_llm_provider="bedrock", logging_obj=logging
|
||||
response, model, custom_llm_provider="bedrock", logging_obj=logging
|
||||
)
|
||||
else:
|
||||
response = CustomStreamWrapper(
|
||||
iter(model_response), model, custom_llm_provider="bedrock", logging_obj=logging
|
||||
iter(response), model, custom_llm_provider="bedrock", logging_obj=logging
|
||||
)
|
||||
return response
|
||||
|
||||
if optional_params.get("stream", False):
|
||||
## LOGGING
|
||||
logging.post_call(
|
||||
input=messages,
|
||||
api_key=None,
|
||||
original_response=response,
|
||||
)
|
||||
|
||||
|
||||
## RESPONSE OBJECT
|
||||
response = model_response
|
||||
response = response
|
||||
elif custom_llm_provider == "vllm":
|
||||
model_response = vllm.completion(
|
||||
model=model,
|
||||
|
@ -1273,14 +1309,18 @@ def completion(
|
|||
)
|
||||
else:
|
||||
prompt = prompt_factory(model=model, messages=messages, custom_llm_provider=custom_llm_provider)
|
||||
## LOGGING
|
||||
if kwargs.get('acompletion', False) == True:
|
||||
if optional_params.get("stream", False) == True:
|
||||
# assume all ollama responses are streamed
|
||||
async_generator = ollama.async_get_ollama_response_stream(api_base, model, prompt, optional_params, logging_obj=logging)
|
||||
return async_generator
|
||||
if isinstance(prompt, dict):
|
||||
# for multimode models - ollama/llava prompt_factory returns a dict {
|
||||
# "prompt": prompt,
|
||||
# "images": images
|
||||
# }
|
||||
prompt, images = prompt["prompt"], prompt["images"]
|
||||
optional_params["images"] = images
|
||||
|
||||
generator = ollama.get_ollama_response_stream(api_base, model, prompt, optional_params, logging_obj=logging)
|
||||
## LOGGING
|
||||
generator = ollama.get_ollama_response_stream(api_base, model, prompt, optional_params, logging_obj=logging, acompletion=acompletion, model_response=model_response, encoding=encoding)
|
||||
if acompletion is True:
|
||||
return generator
|
||||
if optional_params.get("stream", False) == True:
|
||||
# assume all ollama responses are streamed
|
||||
response = CustomStreamWrapper(
|
||||
|
@ -1716,8 +1756,7 @@ async def aembedding(*args, **kwargs):
|
|||
or custom_llm_provider == "anyscale"
|
||||
or custom_llm_provider == "openrouter"
|
||||
or custom_llm_provider == "deepinfra"
|
||||
or custom_llm_provider == "perplexity"
|
||||
or custom_llm_provider == "huggingface"): # currently implemented aiohttp calls for just azure and openai, soon all.
|
||||
or custom_llm_provider == "perplexity"): # currently implemented aiohttp calls for just azure and openai, soon all.
|
||||
# Await normally
|
||||
init_response = await loop.run_in_executor(None, func_with_context)
|
||||
if isinstance(init_response, dict) or isinstance(init_response, ModelResponse): ## CACHING SCENARIO
|
||||
|
@ -1781,22 +1820,21 @@ def embedding(
|
|||
rpm = kwargs.pop("rpm", None)
|
||||
tpm = kwargs.pop("tpm", None)
|
||||
model_info = kwargs.get("model_info", None)
|
||||
metadata = kwargs.get("metadata", None)
|
||||
encoding_format = kwargs.get("encoding_format", None)
|
||||
proxy_server_request = kwargs.get("proxy_server_request", None)
|
||||
aembedding = kwargs.pop("aembedding", None)
|
||||
openai_params = ["functions", "function_call", "temperature", "temperature", "top_p", "n", "stream", "stop", "max_tokens", "presence_penalty", "frequency_penalty", "logit_bias", "user", "request_timeout", "api_base", "api_version", "api_key", "deployment_id", "organization", "base_url", "default_headers", "timeout", "response_format", "seed", "tools", "tool_choice", "max_retries", "encoding_format"]
|
||||
litellm_params = ["metadata", "acompletion", "caching", "return_async", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response", "self", "client", "rpm", "tpm", "input_cost_per_token", "output_cost_per_token", "hf_model_name", "proxy_server_request", "model_info"]
|
||||
aembedding = kwargs.get("aembedding", None)
|
||||
openai_params = ["user", "request_timeout", "api_base", "api_version", "api_key", "deployment_id", "organization", "base_url", "default_headers", "timeout", "max_retries", "encoding_format"]
|
||||
litellm_params = ["metadata", "aembedding", "caching", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response", "self", "client", "rpm", "tpm", "input_cost_per_token", "output_cost_per_token", "hf_model_name", "proxy_server_request", "model_info", "preset_cache_key", "caching_groups"]
|
||||
default_params = openai_params + litellm_params
|
||||
non_default_params = {k: v for k,v in kwargs.items() if k not in default_params} # model-specific params - pass them straight to the model/provider
|
||||
optional_params = {}
|
||||
for param in non_default_params:
|
||||
optional_params[param] = kwargs[param]
|
||||
model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base, api_key=api_key)
|
||||
|
||||
|
||||
model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base, api_key=api_key)
|
||||
optional_params = get_optional_params_embeddings(user=user, encoding_format=encoding_format, custom_llm_provider=custom_llm_provider, **non_default_params)
|
||||
try:
|
||||
response = None
|
||||
logging = litellm_logging_obj
|
||||
logging.update_environment_variables(model=model, user="", optional_params=optional_params, litellm_params={"timeout": timeout, "azure": azure, "litellm_call_id": litellm_call_id, "logger_fn": logger_fn, "proxy_server_request": proxy_server_request, "model_info": model_info})
|
||||
logging.update_environment_variables(model=model, user=user, optional_params=optional_params, litellm_params={"timeout": timeout, "azure": azure, "litellm_call_id": litellm_call_id, "logger_fn": logger_fn, "proxy_server_request": proxy_server_request, "model_info": model_info, "metadata": metadata, "aembedding": aembedding, "preset_cache_key": None, "stream_response": {}})
|
||||
if azure == True or custom_llm_provider == "azure":
|
||||
# azure configs
|
||||
api_type = get_secret("AZURE_API_TYPE") or "azure"
|
||||
|
@ -1936,7 +1974,7 @@ def embedding(
|
|||
## LOGGING
|
||||
logging.post_call(
|
||||
input=input,
|
||||
api_key=openai.api_key,
|
||||
api_key=api_key,
|
||||
original_response=str(e),
|
||||
)
|
||||
## Map to OpenAI Exception
|
||||
|
@ -1948,6 +1986,59 @@ def embedding(
|
|||
|
||||
|
||||
###### Text Completion ################
|
||||
async def atext_completion(*args, **kwargs):
|
||||
"""
|
||||
Implemented to handle async streaming for the text completion endpoint
|
||||
"""
|
||||
loop = asyncio.get_event_loop()
|
||||
model = args[0] if len(args) > 0 else kwargs["model"]
|
||||
### PASS ARGS TO COMPLETION ###
|
||||
kwargs["acompletion"] = True
|
||||
custom_llm_provider = None
|
||||
try:
|
||||
# Use a partial function to pass your keyword arguments
|
||||
func = partial(text_completion, *args, **kwargs)
|
||||
|
||||
# Add the context to the function
|
||||
ctx = contextvars.copy_context()
|
||||
func_with_context = partial(ctx.run, func)
|
||||
|
||||
_, custom_llm_provider, _, _ = get_llm_provider(model=model, api_base=kwargs.get("api_base", None))
|
||||
|
||||
if (custom_llm_provider == "openai"
|
||||
or custom_llm_provider == "azure"
|
||||
or custom_llm_provider == "custom_openai"
|
||||
or custom_llm_provider == "anyscale"
|
||||
or custom_llm_provider == "mistral"
|
||||
or custom_llm_provider == "openrouter"
|
||||
or custom_llm_provider == "deepinfra"
|
||||
or custom_llm_provider == "perplexity"
|
||||
or custom_llm_provider == "text-completion-openai"
|
||||
or custom_llm_provider == "huggingface"
|
||||
or custom_llm_provider == "ollama"
|
||||
or custom_llm_provider == "vertex_ai"): # currently implemented aiohttp calls for just azure and openai, soon all.
|
||||
if kwargs.get("stream", False):
|
||||
response = text_completion(*args, **kwargs)
|
||||
else:
|
||||
# Await normally
|
||||
init_response = await loop.run_in_executor(None, func_with_context)
|
||||
if isinstance(init_response, dict) or isinstance(init_response, ModelResponse): ## CACHING SCENARIO
|
||||
response = init_response
|
||||
elif asyncio.iscoroutine(init_response):
|
||||
response = await init_response
|
||||
else:
|
||||
# Call the synchronous function using run_in_executor
|
||||
response = await loop.run_in_executor(None, func_with_context)
|
||||
if kwargs.get("stream", False): # return an async generator
|
||||
return _async_streaming(response=response, model=model, custom_llm_provider=custom_llm_provider, args=args)
|
||||
else:
|
||||
return response
|
||||
except Exception as e:
|
||||
custom_llm_provider = custom_llm_provider or "openai"
|
||||
raise exception_type(
|
||||
model=model, custom_llm_provider=custom_llm_provider, original_exception=e, completion_kwargs=args,
|
||||
)
|
||||
|
||||
def text_completion(
|
||||
prompt: Union[str, List[Union[str, List[Union[str, List[int]]]]]], # Required: The prompt(s) to generate completions for.
|
||||
model: Optional[str]=None, # Optional: either `model` or `engine` can be set
|
||||
|
@ -2079,7 +2170,7 @@ def text_completion(
|
|||
*args,
|
||||
**all_params,
|
||||
)
|
||||
#print(response)
|
||||
|
||||
text_completion_response["id"] = response.get("id", None)
|
||||
text_completion_response["object"] = "text_completion"
|
||||
text_completion_response["created"] = response.get("created", None)
|
||||
|
@ -2294,6 +2385,8 @@ def stream_chunk_builder(chunks: list, messages: Optional[list]=None):
|
|||
completion_output = combined_content
|
||||
elif len(combined_arguments) > 0:
|
||||
completion_output = combined_arguments
|
||||
else:
|
||||
completion_output = ""
|
||||
# # Update usage information if needed
|
||||
try:
|
||||
response["usage"]["prompt_tokens"] = token_counter(model=model, messages=messages)
|
||||
|
|
|
@ -41,6 +41,20 @@
|
|||
"litellm_provider": "openai",
|
||||
"mode": "chat"
|
||||
},
|
||||
"gpt-4-1106-preview": {
|
||||
"max_tokens": 128000,
|
||||
"input_cost_per_token": 0.00001,
|
||||
"output_cost_per_token": 0.00003,
|
||||
"litellm_provider": "openai",
|
||||
"mode": "chat"
|
||||
},
|
||||
"gpt-4-vision-preview": {
|
||||
"max_tokens": 128000,
|
||||
"input_cost_per_token": 0.00001,
|
||||
"output_cost_per_token": 0.00003,
|
||||
"litellm_provider": "openai",
|
||||
"mode": "chat"
|
||||
},
|
||||
"gpt-3.5-turbo": {
|
||||
"max_tokens": 4097,
|
||||
"input_cost_per_token": 0.0000015,
|
||||
|
@ -62,6 +76,13 @@
|
|||
"litellm_provider": "openai",
|
||||
"mode": "chat"
|
||||
},
|
||||
"gpt-3.5-turbo-1106": {
|
||||
"max_tokens": 16385,
|
||||
"input_cost_per_token": 0.0000010,
|
||||
"output_cost_per_token": 0.0000020,
|
||||
"litellm_provider": "openai",
|
||||
"mode": "chat"
|
||||
},
|
||||
"gpt-3.5-turbo-16k": {
|
||||
"max_tokens": 16385,
|
||||
"input_cost_per_token": 0.000003,
|
||||
|
@ -76,6 +97,62 @@
|
|||
"litellm_provider": "openai",
|
||||
"mode": "chat"
|
||||
},
|
||||
"ft:gpt-3.5-turbo": {
|
||||
"max_tokens": 4097,
|
||||
"input_cost_per_token": 0.000012,
|
||||
"output_cost_per_token": 0.000016,
|
||||
"litellm_provider": "openai",
|
||||
"mode": "chat"
|
||||
},
|
||||
"text-embedding-ada-002": {
|
||||
"max_tokens": 8191,
|
||||
"input_cost_per_token": 0.0000001,
|
||||
"output_cost_per_token": 0.000000,
|
||||
"litellm_provider": "openai",
|
||||
"mode": "embedding"
|
||||
},
|
||||
"azure/gpt-4-1106-preview": {
|
||||
"max_tokens": 128000,
|
||||
"input_cost_per_token": 0.00001,
|
||||
"output_cost_per_token": 0.00003,
|
||||
"litellm_provider": "azure",
|
||||
"mode": "chat"
|
||||
},
|
||||
"azure/gpt-4-32k": {
|
||||
"max_tokens": 8192,
|
||||
"input_cost_per_token": 0.00006,
|
||||
"output_cost_per_token": 0.00012,
|
||||
"litellm_provider": "azure",
|
||||
"mode": "chat"
|
||||
},
|
||||
"azure/gpt-4": {
|
||||
"max_tokens": 16385,
|
||||
"input_cost_per_token": 0.00003,
|
||||
"output_cost_per_token": 0.00006,
|
||||
"litellm_provider": "azure",
|
||||
"mode": "chat"
|
||||
},
|
||||
"azure/gpt-3.5-turbo-16k": {
|
||||
"max_tokens": 16385,
|
||||
"input_cost_per_token": 0.000003,
|
||||
"output_cost_per_token": 0.000004,
|
||||
"litellm_provider": "azure",
|
||||
"mode": "chat"
|
||||
},
|
||||
"azure/gpt-3.5-turbo": {
|
||||
"max_tokens": 4097,
|
||||
"input_cost_per_token": 0.0000015,
|
||||
"output_cost_per_token": 0.000002,
|
||||
"litellm_provider": "azure",
|
||||
"mode": "chat"
|
||||
},
|
||||
"azure/text-embedding-ada-002": {
|
||||
"max_tokens": 8191,
|
||||
"input_cost_per_token": 0.0000001,
|
||||
"output_cost_per_token": 0.000000,
|
||||
"litellm_provider": "azure",
|
||||
"mode": "embedding"
|
||||
},
|
||||
"text-davinci-003": {
|
||||
"max_tokens": 4097,
|
||||
"input_cost_per_token": 0.000002,
|
||||
|
@ -127,6 +204,7 @@
|
|||
},
|
||||
"claude-instant-1": {
|
||||
"max_tokens": 100000,
|
||||
"max_output_tokens": 8191,
|
||||
"input_cost_per_token": 0.00000163,
|
||||
"output_cost_per_token": 0.00000551,
|
||||
"litellm_provider": "anthropic",
|
||||
|
@ -134,15 +212,25 @@
|
|||
},
|
||||
"claude-instant-1.2": {
|
||||
"max_tokens": 100000,
|
||||
"input_cost_per_token": 0.00000163,
|
||||
"output_cost_per_token": 0.00000551,
|
||||
"max_output_tokens": 8191,
|
||||
"input_cost_per_token": 0.000000163,
|
||||
"output_cost_per_token": 0.000000551,
|
||||
"litellm_provider": "anthropic",
|
||||
"mode": "chat"
|
||||
},
|
||||
"claude-2": {
|
||||
"max_tokens": 100000,
|
||||
"input_cost_per_token": 0.00001102,
|
||||
"output_cost_per_token": 0.00003268,
|
||||
"max_output_tokens": 8191,
|
||||
"input_cost_per_token": 0.000008,
|
||||
"output_cost_per_token": 0.000024,
|
||||
"litellm_provider": "anthropic",
|
||||
"mode": "chat"
|
||||
},
|
||||
"claude-2.1": {
|
||||
"max_tokens": 200000,
|
||||
"max_output_tokens": 8191,
|
||||
"input_cost_per_token": 0.000008,
|
||||
"output_cost_per_token": 0.000024,
|
||||
"litellm_provider": "anthropic",
|
||||
"mode": "chat"
|
||||
},
|
||||
|
@ -227,9 +315,51 @@
|
|||
"max_tokens": 32000,
|
||||
"input_cost_per_token": 0.000000125,
|
||||
"output_cost_per_token": 0.000000125,
|
||||
"litellm_provider": "vertex_ai-chat-models",
|
||||
"litellm_provider": "vertex_ai-code-chat-models",
|
||||
"mode": "chat"
|
||||
},
|
||||
"palm/chat-bison": {
|
||||
"max_tokens": 4096,
|
||||
"input_cost_per_token": 0.000000125,
|
||||
"output_cost_per_token": 0.000000125,
|
||||
"litellm_provider": "palm",
|
||||
"mode": "chat"
|
||||
},
|
||||
"palm/chat-bison-001": {
|
||||
"max_tokens": 4096,
|
||||
"input_cost_per_token": 0.000000125,
|
||||
"output_cost_per_token": 0.000000125,
|
||||
"litellm_provider": "palm",
|
||||
"mode": "chat"
|
||||
},
|
||||
"palm/text-bison": {
|
||||
"max_tokens": 8196,
|
||||
"input_cost_per_token": 0.000000125,
|
||||
"output_cost_per_token": 0.000000125,
|
||||
"litellm_provider": "palm",
|
||||
"mode": "completion"
|
||||
},
|
||||
"palm/text-bison-001": {
|
||||
"max_tokens": 8196,
|
||||
"input_cost_per_token": 0.000000125,
|
||||
"output_cost_per_token": 0.000000125,
|
||||
"litellm_provider": "palm",
|
||||
"mode": "completion"
|
||||
},
|
||||
"palm/text-bison-safety-off": {
|
||||
"max_tokens": 8196,
|
||||
"input_cost_per_token": 0.000000125,
|
||||
"output_cost_per_token": 0.000000125,
|
||||
"litellm_provider": "palm",
|
||||
"mode": "completion"
|
||||
},
|
||||
"palm/text-bison-safety-recitation-off": {
|
||||
"max_tokens": 8196,
|
||||
"input_cost_per_token": 0.000000125,
|
||||
"output_cost_per_token": 0.000000125,
|
||||
"litellm_provider": "palm",
|
||||
"mode": "completion"
|
||||
},
|
||||
"command-nightly": {
|
||||
"max_tokens": 4096,
|
||||
"input_cost_per_token": 0.000015,
|
||||
|
@ -267,6 +397,8 @@
|
|||
},
|
||||
"replicate/llama-2-70b-chat:2c1608e18606fad2812020dc541930f2d0495ce32eee50074220b87300bc16e1": {
|
||||
"max_tokens": 4096,
|
||||
"input_cost_per_token": 0.0000,
|
||||
"output_cost_per_token": 0.0000,
|
||||
"litellm_provider": "replicate",
|
||||
"mode": "chat"
|
||||
},
|
||||
|
@ -293,6 +425,7 @@
|
|||
},
|
||||
"openrouter/anthropic/claude-instant-v1": {
|
||||
"max_tokens": 100000,
|
||||
"max_output_tokens": 8191,
|
||||
"input_cost_per_token": 0.00000163,
|
||||
"output_cost_per_token": 0.00000551,
|
||||
"litellm_provider": "openrouter",
|
||||
|
@ -300,6 +433,7 @@
|
|||
},
|
||||
"openrouter/anthropic/claude-2": {
|
||||
"max_tokens": 100000,
|
||||
"max_output_tokens": 8191,
|
||||
"input_cost_per_token": 0.00001102,
|
||||
"output_cost_per_token": 0.00003268,
|
||||
"litellm_provider": "openrouter",
|
||||
|
@ -496,20 +630,31 @@
|
|||
},
|
||||
"anthropic.claude-v1": {
|
||||
"max_tokens": 100000,
|
||||
"input_cost_per_token": 0.00001102,
|
||||
"output_cost_per_token": 0.00003268,
|
||||
"max_output_tokens": 8191,
|
||||
"input_cost_per_token": 0.000008,
|
||||
"output_cost_per_token": 0.000024,
|
||||
"litellm_provider": "bedrock",
|
||||
"mode": "chat"
|
||||
},
|
||||
"anthropic.claude-v2": {
|
||||
"max_tokens": 100000,
|
||||
"input_cost_per_token": 0.00001102,
|
||||
"output_cost_per_token": 0.00003268,
|
||||
"max_output_tokens": 8191,
|
||||
"input_cost_per_token": 0.000008,
|
||||
"output_cost_per_token": 0.000024,
|
||||
"litellm_provider": "bedrock",
|
||||
"mode": "chat"
|
||||
},
|
||||
"anthropic.claude-v2:1": {
|
||||
"max_tokens": 200000,
|
||||
"max_output_tokens": 8191,
|
||||
"input_cost_per_token": 0.000008,
|
||||
"output_cost_per_token": 0.000024,
|
||||
"litellm_provider": "bedrock",
|
||||
"mode": "chat"
|
||||
},
|
||||
"anthropic.claude-instant-v1": {
|
||||
"max_tokens": 100000,
|
||||
"max_output_tokens": 8191,
|
||||
"input_cost_per_token": 0.00000163,
|
||||
"output_cost_per_token": 0.00000551,
|
||||
"litellm_provider": "bedrock",
|
||||
|
@ -529,26 +674,80 @@
|
|||
"litellm_provider": "bedrock",
|
||||
"mode": "chat"
|
||||
},
|
||||
"meta.llama2-70b-chat-v1": {
|
||||
"max_tokens": 4096,
|
||||
"input_cost_per_token": 0.00000195,
|
||||
"output_cost_per_token": 0.00000256,
|
||||
"litellm_provider": "bedrock",
|
||||
"mode": "chat"
|
||||
},
|
||||
"sagemaker/meta-textgeneration-llama-2-7b": {
|
||||
"max_tokens": 4096,
|
||||
"input_cost_per_token": 0.000,
|
||||
"output_cost_per_token": 0.000,
|
||||
"litellm_provider": "sagemaker",
|
||||
"mode": "completion"
|
||||
},
|
||||
"sagemaker/meta-textgeneration-llama-2-7b-f": {
|
||||
"max_tokens": 4096,
|
||||
"input_cost_per_token": 0.000,
|
||||
"output_cost_per_token": 0.000,
|
||||
"litellm_provider": "sagemaker",
|
||||
"mode": "chat"
|
||||
},
|
||||
"sagemaker/meta-textgeneration-llama-2-13b": {
|
||||
"max_tokens": 4096,
|
||||
"input_cost_per_token": 0.000,
|
||||
"output_cost_per_token": 0.000,
|
||||
"litellm_provider": "sagemaker",
|
||||
"mode": "completion"
|
||||
},
|
||||
"sagemaker/meta-textgeneration-llama-2-13b-f": {
|
||||
"max_tokens": 4096,
|
||||
"input_cost_per_token": 0.000,
|
||||
"output_cost_per_token": 0.000,
|
||||
"litellm_provider": "sagemaker",
|
||||
"mode": "chat"
|
||||
},
|
||||
"sagemaker/meta-textgeneration-llama-2-70b": {
|
||||
"max_tokens": 4096,
|
||||
"input_cost_per_token": 0.000,
|
||||
"output_cost_per_token": 0.000,
|
||||
"litellm_provider": "sagemaker",
|
||||
"mode": "completion"
|
||||
},
|
||||
"sagemaker/meta-textgeneration-llama-2-70b-b-f": {
|
||||
"max_tokens": 4096,
|
||||
"input_cost_per_token": 0.000,
|
||||
"output_cost_per_token": 0.000,
|
||||
"litellm_provider": "sagemaker",
|
||||
"mode": "chat"
|
||||
},
|
||||
"together-ai-up-to-3b": {
|
||||
"input_cost_per_token": 0.0000001,
|
||||
"output_cost_per_token": 0.0000001
|
||||
"output_cost_per_token": 0.0000001,
|
||||
"litellm_provider": "together_ai"
|
||||
},
|
||||
"together-ai-3.1b-7b": {
|
||||
"input_cost_per_token": 0.0000002,
|
||||
"output_cost_per_token": 0.0000002
|
||||
"output_cost_per_token": 0.0000002,
|
||||
"litellm_provider": "together_ai"
|
||||
},
|
||||
"together-ai-7.1b-20b": {
|
||||
"max_tokens": 1000,
|
||||
"input_cost_per_token": 0.0000004,
|
||||
"output_cost_per_token": 0.0000004
|
||||
"output_cost_per_token": 0.0000004,
|
||||
"litellm_provider": "together_ai"
|
||||
},
|
||||
"together-ai-20.1b-40b": {
|
||||
"input_cost_per_token": 0.000001,
|
||||
"output_cost_per_token": 0.000001
|
||||
"input_cost_per_token": 0.0000008,
|
||||
"output_cost_per_token": 0.0000008,
|
||||
"litellm_provider": "together_ai"
|
||||
},
|
||||
"together-ai-40.1b-70b": {
|
||||
"input_cost_per_token": 0.000003,
|
||||
"output_cost_per_token": 0.000003
|
||||
"input_cost_per_token": 0.0000009,
|
||||
"output_cost_per_token": 0.0000009,
|
||||
"litellm_provider": "together_ai"
|
||||
},
|
||||
"ollama/llama2": {
|
||||
"max_tokens": 4096,
|
||||
|
@ -578,10 +777,38 @@
|
|||
"litellm_provider": "ollama",
|
||||
"mode": "completion"
|
||||
},
|
||||
"ollama/mistral": {
|
||||
"max_tokens": 8192,
|
||||
"input_cost_per_token": 0.0,
|
||||
"output_cost_per_token": 0.0,
|
||||
"litellm_provider": "ollama",
|
||||
"mode": "completion"
|
||||
},
|
||||
"ollama/codellama": {
|
||||
"max_tokens": 4096,
|
||||
"input_cost_per_token": 0.0,
|
||||
"output_cost_per_token": 0.0,
|
||||
"litellm_provider": "ollama",
|
||||
"mode": "completion"
|
||||
},
|
||||
"ollama/orca-mini": {
|
||||
"max_tokens": 4096,
|
||||
"input_cost_per_token": 0.0,
|
||||
"output_cost_per_token": 0.0,
|
||||
"litellm_provider": "ollama",
|
||||
"mode": "completion"
|
||||
},
|
||||
"ollama/vicuna": {
|
||||
"max_tokens": 2048,
|
||||
"input_cost_per_token": 0.0,
|
||||
"output_cost_per_token": 0.0,
|
||||
"litellm_provider": "ollama",
|
||||
"mode": "completion"
|
||||
},
|
||||
"deepinfra/meta-llama/Llama-2-70b-chat-hf": {
|
||||
"max_tokens": 6144,
|
||||
"input_cost_per_token": 0.000001875,
|
||||
"output_cost_per_token": 0.000001875,
|
||||
"max_tokens": 4096,
|
||||
"input_cost_per_token": 0.000000700,
|
||||
"output_cost_per_token": 0.000000950,
|
||||
"litellm_provider": "deepinfra",
|
||||
"mode": "chat"
|
||||
},
|
||||
|
@ -619,5 +846,103 @@
|
|||
"output_cost_per_token": 0.00000095,
|
||||
"litellm_provider": "deepinfra",
|
||||
"mode": "chat"
|
||||
},
|
||||
"perplexity/pplx-7b-chat": {
|
||||
"max_tokens": 8192,
|
||||
"input_cost_per_token": 0.0000000,
|
||||
"output_cost_per_token": 0.000000,
|
||||
"litellm_provider": "perplexity",
|
||||
"mode": "chat"
|
||||
},
|
||||
"perplexity/pplx-70b-chat": {
|
||||
"max_tokens": 4096,
|
||||
"input_cost_per_token": 0.0000000,
|
||||
"output_cost_per_token": 0.000000,
|
||||
"litellm_provider": "perplexity",
|
||||
"mode": "chat"
|
||||
},
|
||||
"perplexity/pplx-7b-online": {
|
||||
"max_tokens": 4096,
|
||||
"input_cost_per_token": 0.0000000,
|
||||
"output_cost_per_token": 0.0005,
|
||||
"litellm_provider": "perplexity",
|
||||
"mode": "chat"
|
||||
},
|
||||
"perplexity/pplx-70b-online": {
|
||||
"max_tokens": 4096,
|
||||
"input_cost_per_token": 0.0000000,
|
||||
"output_cost_per_token": 0.0005,
|
||||
"litellm_provider": "perplexity",
|
||||
"mode": "chat"
|
||||
},
|
||||
"perplexity/llama-2-13b-chat": {
|
||||
"max_tokens": 4096,
|
||||
"input_cost_per_token": 0.0000000,
|
||||
"output_cost_per_token": 0.000000,
|
||||
"litellm_provider": "perplexity",
|
||||
"mode": "chat"
|
||||
},
|
||||
"perplexity/llama-2-70b-chat": {
|
||||
"max_tokens": 4096,
|
||||
"input_cost_per_token": 0.0000000,
|
||||
"output_cost_per_token": 0.000000,
|
||||
"litellm_provider": "perplexity",
|
||||
"mode": "chat"
|
||||
},
|
||||
"perplexity/mistral-7b-instruct": {
|
||||
"max_tokens": 4096,
|
||||
"input_cost_per_token": 0.0000000,
|
||||
"output_cost_per_token": 0.000000,
|
||||
"litellm_provider": "perplexity",
|
||||
"mode": "chat"
|
||||
},
|
||||
"perplexity/replit-code-v1.5-3b": {
|
||||
"max_tokens": 4096,
|
||||
"input_cost_per_token": 0.0000000,
|
||||
"output_cost_per_token": 0.000000,
|
||||
"litellm_provider": "perplexity",
|
||||
"mode": "chat"
|
||||
},
|
||||
"anyscale/mistralai/Mistral-7B-Instruct-v0.1": {
|
||||
"max_tokens": 16384,
|
||||
"input_cost_per_token": 0.00000015,
|
||||
"output_cost_per_token": 0.00000015,
|
||||
"litellm_provider": "anyscale",
|
||||
"mode": "chat"
|
||||
},
|
||||
"anyscale/HuggingFaceH4/zephyr-7b-beta": {
|
||||
"max_tokens": 16384,
|
||||
"input_cost_per_token": 0.00000015,
|
||||
"output_cost_per_token": 0.00000015,
|
||||
"litellm_provider": "anyscale",
|
||||
"mode": "chat"
|
||||
},
|
||||
"anyscale/meta-llama/Llama-2-7b-chat-hf": {
|
||||
"max_tokens": 4096,
|
||||
"input_cost_per_token": 0.00000015,
|
||||
"output_cost_per_token": 0.00000015,
|
||||
"litellm_provider": "anyscale",
|
||||
"mode": "chat"
|
||||
},
|
||||
"anyscale/meta-llama/Llama-2-13b-chat-hf": {
|
||||
"max_tokens": 4096,
|
||||
"input_cost_per_token": 0.00000025,
|
||||
"output_cost_per_token": 0.00000025,
|
||||
"litellm_provider": "anyscale",
|
||||
"mode": "chat"
|
||||
},
|
||||
"anyscale/meta-llama/Llama-2-70b-chat-hf": {
|
||||
"max_tokens": 4096,
|
||||
"input_cost_per_token": 0.000001,
|
||||
"output_cost_per_token": 0.000001,
|
||||
"litellm_provider": "anyscale",
|
||||
"mode": "chat"
|
||||
},
|
||||
"anyscale/codellama/CodeLlama-34b-Instruct-hf": {
|
||||
"max_tokens": 16384,
|
||||
"input_cost_per_token": 0.000001,
|
||||
"output_cost_per_token": 0.000001,
|
||||
"litellm_provider": "anyscale",
|
||||
"mode": "chat"
|
||||
}
|
||||
}
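As a reading aid for the pricing entries above: `input_cost_per_token` and `output_cost_per_token` are per-token USD rates, and for token-priced models the proxy's cost tracking (via `litellm.completion_cost`, used later in this commit) effectively reduces to the product sketched below. A minimal, hypothetical helper, assuming the token counts are already known:

```python
# Hypothetical helper mirroring how the per-token rates in the JSON map are applied.
def estimate_cost(prompt_tokens: int, completion_tokens: int,
                  input_cost_per_token: float, output_cost_per_token: float) -> float:
    return prompt_tokens * input_cost_per_token + completion_tokens * output_cost_per_token

# e.g. the gpt-3.5-turbo-1106 rates from the table above
print(estimate_cost(1000, 500, 0.0000010, 0.0000020))  # 0.002 USD
```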
|
||||
|
|
litellm/proxy/_experimental/post_call_rules.py (new file, 4 lines)
@@ -0,0 +1,4 @@
def my_custom_rule(input): # receives the model response
    # if len(input) < 5: # trigger fallback if the model response is too short
    return False
    return True
|
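Not part of the commit - a rough sketch of how such a rule can be attached in code. The proxy wires this up from `post_call_rules` in config later in this commit via `get_instance_fn`; the rule body and the fallback semantics noted in the comment are assumptions based on the file above.

```python
import litellm

def my_custom_rule(input):  # receives the model response text, as in the new file above
    # assumption: returning False marks the response as failing the rule, triggering fallback handling
    return len(input) >= 5

litellm.post_call_rules = [my_custom_rule]  # the same attribute the proxy sets from config
```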
@@ -2,8 +2,21 @@ from pydantic import BaseModel, Extra, Field, root_validator
from typing import Optional, List, Union, Dict, Literal
from datetime import datetime
import uuid, json

class LiteLLMBase(BaseModel):
    """
    Implements default functions, all pydantic objects should have.
    """
    def json(self, **kwargs):
        try:
            return self.model_dump() # noqa
        except:
            # if using pydantic v1
            return self.dict()


######### Request Class Definition ######
class ProxyChatCompletionRequest(BaseModel):
class ProxyChatCompletionRequest(LiteLLMBase):
    model: str
    messages: List[Dict[str, str]]
    temperature: Optional[float] = None
||||
|
@ -38,16 +51,16 @@ class ProxyChatCompletionRequest(BaseModel):
|
|||
class Config:
|
||||
extra='allow' # allow params not defined here, these fall in litellm.completion(**kwargs)
|
||||
|
||||
class ModelInfoDelete(BaseModel):
|
||||
class ModelInfoDelete(LiteLLMBase):
|
||||
id: Optional[str]
|
||||
|
||||
|
||||
class ModelInfo(BaseModel):
|
||||
class ModelInfo(LiteLLMBase):
|
||||
id: Optional[str]
|
||||
mode: Optional[Literal['embedding', 'chat', 'completion']]
|
||||
input_cost_per_token: Optional[float]
|
||||
output_cost_per_token: Optional[float]
|
||||
max_tokens: Optional[int]
|
||||
input_cost_per_token: Optional[float] = 0.0
|
||||
output_cost_per_token: Optional[float] = 0.0
|
||||
max_tokens: Optional[int] = 2048 # assume 2048 if not set
|
||||
|
||||
# for azure models we need users to specify the base model, one azure you can call deployments - azure/my-random-model
|
||||
# we look up the base model in model_prices_and_context_window.json
|
||||
|
@ -65,38 +78,41 @@ class ModelInfo(BaseModel):
|
|||
class Config:
|
||||
extra = Extra.allow # Allow extra fields
|
||||
protected_namespaces = ()
|
||||
|
||||
|
||||
# @root_validator(pre=True)
|
||||
# def set_model_info(cls, values):
|
||||
# if values.get("id") is None:
|
||||
# values.update({"id": str(uuid.uuid4())})
|
||||
# if values.get("mode") is None:
|
||||
# values.update({"mode": str(uuid.uuid4())})
|
||||
# return values
|
||||
@root_validator(pre=True)
|
||||
def set_model_info(cls, values):
|
||||
if values.get("id") is None:
|
||||
values.update({"id": str(uuid.uuid4())})
|
||||
if values.get("mode") is None:
|
||||
values.update({"mode": None})
|
||||
if values.get("input_cost_per_token") is None:
|
||||
values.update({"input_cost_per_token": None})
|
||||
if values.get("output_cost_per_token") is None:
|
||||
values.update({"output_cost_per_token": None})
|
||||
if values.get("max_tokens") is None:
|
||||
values.update({"max_tokens": None})
|
||||
if values.get("base_model") is None:
|
||||
values.update({"base_model": None})
|
||||
return values
|
||||
|
||||
|
||||
|
||||
class ModelParams(BaseModel):
|
||||
class ModelParams(LiteLLMBase):
|
||||
model_name: str
|
||||
litellm_params: dict
|
||||
model_info: Optional[ModelInfo]=None
|
||||
model_info: ModelInfo
|
||||
|
||||
# def __init__(self, model_name: str, litellm_params: dict, model_info: Optional[ModelInfo] = None):
|
||||
# self.model_name = model_name
|
||||
# self.litellm_params = litellm_params
|
||||
# self.model_info = model_info if model_info else ModelInfo()
|
||||
# super.__init__(model_name=self.model_name, litellm_params=self.litellm_params, model_info=self.model_info)
|
||||
|
||||
class Config:
|
||||
protected_namespaces = ()
|
||||
|
||||
# @root_validator(pre=True)
|
||||
# def set_model_info(cls, values):
|
||||
# if values.get("model_info") is None:
|
||||
# values.update({"model_info": ModelInfo()})
|
||||
# return values
|
||||
@root_validator(pre=True)
|
||||
def set_model_info(cls, values):
|
||||
if values.get("model_info") is None:
|
||||
values.update({"model_info": ModelInfo()})
|
||||
return values
|
||||
|
||||
class GenerateKeyRequest(BaseModel):
|
||||
class GenerateKeyRequest(LiteLLMBase):
|
||||
duration: Optional[str] = "1h"
|
||||
models: Optional[list] = []
|
||||
aliases: Optional[dict] = {}
|
||||
|
@ -105,26 +121,32 @@ class GenerateKeyRequest(BaseModel):
|
|||
user_id: Optional[str] = None
|
||||
max_parallel_requests: Optional[int] = None
|
||||
|
||||
def json(self, **kwargs):
|
||||
try:
|
||||
return self.model_dump() # noqa
|
||||
except:
|
||||
# if using pydantic v1
|
||||
return self.dict()
|
||||
class UpdateKeyRequest(LiteLLMBase):
|
||||
key: str
|
||||
duration: Optional[str] = None
|
||||
models: Optional[list] = None
|
||||
aliases: Optional[dict] = None
|
||||
config: Optional[dict] = None
|
||||
spend: Optional[float] = None
|
||||
user_id: Optional[str] = None
|
||||
max_parallel_requests: Optional[int] = None
|
||||
|
||||
class GenerateKeyResponse(BaseModel):
|
||||
class GenerateKeyResponse(LiteLLMBase):
|
||||
key: str
|
||||
expires: datetime
|
||||
user_id: str
|
||||
|
||||
class _DeleteKeyObject(BaseModel):
|
||||
|
||||
|
||||
|
||||
class _DeleteKeyObject(LiteLLMBase):
|
||||
key: str
|
||||
|
||||
class DeleteKeyRequest(BaseModel):
|
||||
class DeleteKeyRequest(LiteLLMBase):
|
||||
keys: List[_DeleteKeyObject]
|
||||
|
||||
|
||||
class UserAPIKeyAuth(BaseModel): # the expected response object for user api key auth
|
||||
class UserAPIKeyAuth(LiteLLMBase): # the expected response object for user api key auth
|
||||
"""
|
||||
Return the row in the db
|
||||
"""
|
||||
|
@ -137,7 +159,7 @@ class UserAPIKeyAuth(BaseModel): # the expected response object for user api key
|
|||
max_parallel_requests: Optional[int] = None
|
||||
duration: str = "1h"
|
||||
|
||||
class ConfigGeneralSettings(BaseModel):
|
||||
class ConfigGeneralSettings(LiteLLMBase):
|
||||
"""
|
||||
Documents all the fields supported by `general_settings` in config.yaml
|
||||
"""
|
||||
|
@ -153,10 +175,12 @@ class ConfigGeneralSettings(BaseModel):
|
|||
health_check_interval: int = Field(300, description="background health check interval in seconds")
|
||||
|
||||
|
||||
class ConfigYAML(BaseModel):
|
||||
class ConfigYAML(LiteLLMBase):
|
||||
"""
|
||||
Documents all the fields supported by the config.yaml
|
||||
"""
|
||||
model_list: Optional[List[ModelParams]] = Field(None, description="List of supported models on the server, with model-specific configs")
|
||||
litellm_settings: Optional[dict] = Field(None, description="litellm Module settings. See __init__.py for all, example litellm.drop_params=True, litellm.set_verbose=True, litellm.api_base, litellm.cache")
|
||||
general_settings: Optional[ConfigGeneralSettings] = None
|
||||
class Config:
|
||||
protected_namespaces = ()
|
||||
|
|
|
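To make the `LiteLLMBase` change concrete: the shared `json()` helper lets these request objects serialize under both pydantic v1 (`.dict()`) and v2 (`.model_dump()`). A small, hypothetical usage sketch; the field values are illustrative only.

```python
from litellm.proxy._types import GenerateKeyRequest, UpdateKeyRequest

# Works on either pydantic major version because LiteLLMBase.json() falls back
# from model_dump() to dict().
req = GenerateKeyRequest(duration="1h", models=["gpt-3.5-turbo"], aliases={})
print(req.json())  # plain dict of the request fields

update = UpdateKeyRequest(key="sk-...", spend=0.0)
print(update.json())
```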
@ -1,3 +1,11 @@
|
|||
import sys, os, traceback
|
||||
|
||||
# this file is to test litellm/proxy
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../..")
|
||||
) # Adds the parent directory to the system path
|
||||
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
import litellm
|
||||
import inspect
|
||||
|
@ -36,9 +44,12 @@ class MyCustomHandler(CustomLogger):
|
|||
|
||||
def log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||
print_verbose("On Success!")
|
||||
|
||||
|
||||
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||
print_verbose(f"On Async Success!")
|
||||
response_cost = litellm.completion_cost(completion_response=response_obj)
|
||||
assert response_cost > 0.0
|
||||
return
|
||||
|
||||
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
|
||||
|
|
|
@ -69,7 +69,6 @@ async def _perform_health_check(model_list: list):
|
|||
for model in model_list:
|
||||
litellm_params = model["litellm_params"]
|
||||
model_info = model.get("model_info", {})
|
||||
litellm_params["model"] = litellm.utils.remove_model_id(litellm_params["model"])
|
||||
litellm_params["messages"] = _get_random_llm_message()
|
||||
|
||||
prepped_params.append(litellm_params)
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
from typing import Optional
|
||||
import litellm
|
||||
from litellm.caching import DualCache
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from fastapi import HTTPException
|
||||
|
||||
|
@@ -14,24 +15,28 @@ class MaxParallelRequestsHandler(CustomLogger):
        print(print_statement) # noqa


    async def max_parallel_request_allow_request(self, max_parallel_requests: Optional[int], api_key: Optional[str], user_api_key_cache: DualCache):
    async def async_pre_call_hook(self, user_api_key_dict: UserAPIKeyAuth, cache: DualCache, data: dict, call_type: str):
        self.print_verbose(f"Inside Max Parallel Request Pre-Call Hook")
        api_key = user_api_key_dict.api_key
        max_parallel_requests = user_api_key_dict.max_parallel_requests

        if api_key is None:
            return

        if max_parallel_requests is None:
            return

        self.user_api_key_cache = user_api_key_cache # save the api key cache for updating the value
        self.user_api_key_cache = cache # save the api key cache for updating the value

        # CHECK IF REQUEST ALLOWED
        request_count_api_key = f"{api_key}_request_count"
        current = user_api_key_cache.get_cache(key=request_count_api_key)
        current = cache.get_cache(key=request_count_api_key)
        self.print_verbose(f"current: {current}")
        if current is None:
            user_api_key_cache.set_cache(request_count_api_key, 1)
            cache.set_cache(request_count_api_key, 1)
        elif int(current) < max_parallel_requests:
            # Increase count for this token
            user_api_key_cache.set_cache(request_count_api_key, int(current) + 1)
            cache.set_cache(request_count_api_key, int(current) + 1)
        else:
            raise HTTPException(status_code=429, detail="Max parallel request limit reached.")
||||
|
||||
|
@ -55,16 +60,24 @@ class MaxParallelRequestsHandler(CustomLogger):
|
|||
except Exception as e:
|
||||
self.print_verbose(e) # noqa
|
||||
|
||||
async def async_log_failure_call(self, api_key, user_api_key_cache):
|
||||
async def async_log_failure_call(self, user_api_key_dict: UserAPIKeyAuth, original_exception: Exception):
|
||||
try:
|
||||
self.print_verbose(f"Inside Max Parallel Request Failure Hook")
|
||||
api_key = user_api_key_dict.api_key
|
||||
if api_key is None:
|
||||
return
|
||||
|
||||
request_count_api_key = f"{api_key}_request_count"
|
||||
# Decrease count for this token
|
||||
current = self.user_api_key_cache.get_cache(key=request_count_api_key) or 1
|
||||
new_val = current - 1
|
||||
self.print_verbose(f"updated_value in failure call: {new_val}")
|
||||
self.user_api_key_cache.set_cache(request_count_api_key, new_val)
|
||||
## decrement call count if call failed
|
||||
if (hasattr(original_exception, "status_code")
|
||||
and original_exception.status_code == 429
|
||||
and "Max parallel request limit reached" in str(original_exception)):
|
||||
pass # ignore failed calls due to max limit being reached
|
||||
else:
|
||||
request_count_api_key = f"{api_key}_request_count"
|
||||
# Decrease count for this token
|
||||
current = self.user_api_key_cache.get_cache(key=request_count_api_key) or 1
|
||||
new_val = current - 1
|
||||
self.print_verbose(f"updated_value in failure call: {new_val}")
|
||||
self.user_api_key_cache.set_cache(request_count_api_key, new_val)
|
||||
except Exception as e:
|
||||
self.print_verbose(f"An exception occurred - {str(e)}") # noqa
|
|
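Not part of the commit - a rough, test-style sketch of the admission logic the reworked hooks implement. The module path, the no-argument handler construction, and the field values are assumptions; the counting and 429 behaviour mirror the diff above.

```python
import asyncio
from fastapi import HTTPException
from litellm.caching import DualCache
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.hooks.parallel_request_limiter import MaxParallelRequestsHandler  # assumed module path

async def demo():
    cache = DualCache()
    handler = MaxParallelRequestsHandler()
    key = UserAPIKeyAuth(api_key="sk-test", max_parallel_requests=1)

    # first in-flight call is admitted and sets "<api_key>_request_count" to 1 ...
    await handler.async_pre_call_hook(user_api_key_dict=key, cache=cache, data={}, call_type="completion")
    try:
        # ... a second call while the first has not finished exceeds the limit
        await handler.async_pre_call_hook(user_api_key_dict=key, cache=cache, data={}, call_type="completion")
    except HTTPException as e:
        print(e.status_code, e.detail)  # 429, "Max parallel request limit reached."

asyncio.run(demo())
```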
@ -3,6 +3,7 @@ import subprocess, traceback, json
|
|||
import os, sys
|
||||
import random, appdirs
|
||||
from datetime import datetime
|
||||
import importlib
|
||||
from dotenv import load_dotenv
|
||||
import operator
|
||||
sys.path.append(os.getcwd())
|
||||
|
@ -76,13 +77,14 @@ def is_port_in_use(port):
|
|||
@click.option('--config', '-c', default=None, help='Path to the proxy configuration file (e.g. config.yaml). Usage `litellm --config config.yaml`')
|
||||
@click.option('--max_budget', default=None, type=float, help='Set max budget for API calls - works for hosted models like OpenAI, TogetherAI, Anthropic, etc.`')
|
||||
@click.option('--telemetry', default=True, type=bool, help='Helps us know if people are using this feature. Turn this off by doing `--telemetry False`')
|
||||
@click.option('--version', '-v', default=False, is_flag=True, type=bool, help='Print LiteLLM version')
|
||||
@click.option('--logs', flag_value=False, type=int, help='Gets the "n" most recent logs. By default gets most recent log.')
|
||||
@click.option('--health', flag_value=True, help='Make a chat/completions request to all llms in config.yaml')
|
||||
@click.option('--test', flag_value=True, help='proxy chat completions url to make a test request to')
|
||||
@click.option('--test_async', default=False, is_flag=True, help='Calls async endpoints /queue/requests and /queue/response')
|
||||
@click.option('--num_requests', default=10, type=int, help='Number of requests to hit async endpoint with')
|
||||
@click.option('--local', is_flag=True, default=False, help='for local debugging')
|
||||
def run_server(host, port, api_base, api_version, model, alias, add_key, headers, save, debug, temperature, max_tokens, request_timeout, drop_params, add_function_to_prompt, config, max_budget, telemetry, logs, test, local, num_workers, test_async, num_requests, use_queue, health):
|
||||
def run_server(host, port, api_base, api_version, model, alias, add_key, headers, save, debug, temperature, max_tokens, request_timeout, drop_params, add_function_to_prompt, config, max_budget, telemetry, logs, test, local, num_workers, test_async, num_requests, use_queue, health, version):
|
||||
global feature_telemetry
|
||||
args = locals()
|
||||
if local:
|
||||
|
@ -113,6 +115,10 @@ def run_server(host, port, api_base, api_version, model, alias, add_key, headers
|
|||
except:
|
||||
raise Exception("LiteLLM: No logs saved!")
|
||||
return
|
||||
if version == True:
|
||||
pkg_version = importlib.metadata.version("litellm")
|
||||
click.echo(f'\nLiteLLM: Current Version = {pkg_version}\n')
|
||||
return
|
||||
if model and "ollama" in model and api_base is None:
|
||||
run_ollama_serve()
|
||||
if test_async is True:
|
||||
|
|
|
@ -11,8 +11,10 @@ model_list:
|
|||
output_cost_per_token: 0.00003
|
||||
max_tokens: 4096
|
||||
base_model: gpt-3.5-turbo
|
||||
|
||||
- model_name: openai-gpt-3.5
|
||||
- model_name: BEDROCK_GROUP
|
||||
litellm_params:
|
||||
model: bedrock/cohere.command-text-v14
|
||||
- model_name: Azure OpenAI GPT-4 Canada-East (External)
|
||||
litellm_params:
|
||||
model: gpt-3.5-turbo
|
||||
api_key: os.environ/OPENAI_API_KEY
|
||||
|
@ -41,11 +43,12 @@ model_list:
|
|||
mode: completion
|
||||
|
||||
litellm_settings:
|
||||
# cache: True
|
||||
# setting callback class
|
||||
# callbacks: custom_callbacks.proxy_handler_instance # sets litellm.callbacks = [proxy_handler_instance]
|
||||
model_group_alias_map: {"gpt-4": "openai-gpt-3.5"} # all requests with gpt-4 model_name, get sent to openai-gpt-3.5
|
||||
|
||||
|
||||
general_settings:
|
||||
|
||||
environment_variables:
|
||||
# otel: True # OpenTelemetry Logger
|
||||
# master_key: sk-1234 # [OPTIONAL] Only use this if you to require all calls to contain this key (Authorization: Bearer sk-1234)
|
||||
|
|
|
@ -195,8 +195,10 @@ prisma_client: Optional[PrismaClient] = None
|
|||
user_api_key_cache = DualCache()
|
||||
user_custom_auth = None
|
||||
use_background_health_checks = None
|
||||
use_queue = False
|
||||
health_check_interval = None
|
||||
health_check_results = {}
|
||||
queue: List = []
|
||||
### INITIALIZE GLOBAL LOGGING OBJECT ###
|
||||
proxy_logging_obj = ProxyLogging(user_api_key_cache=user_api_key_cache)
|
||||
### REDIS QUEUE ###
|
||||
|
@ -252,51 +254,58 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap
|
|||
if api_key is None: # only require api key if master key is set
|
||||
raise Exception(f"No api key passed in.")
|
||||
|
||||
route = request.url.path
|
||||
route: str = request.url.path
|
||||
|
||||
# note: never string compare api keys, this is vulenerable to a time attack. Use secrets.compare_digest instead
|
||||
is_master_key_valid = secrets.compare_digest(api_key, master_key)
|
||||
if is_master_key_valid:
|
||||
return UserAPIKeyAuth(api_key=master_key)
|
||||
|
||||
if (route == "/key/generate" or route == "/key/delete" or route == "/key/info") and not is_master_key_valid:
|
||||
raise Exception(f"If master key is set, only master key can be used to generate, delete or get info for new keys")
|
||||
if route.startswith("/key/") and not is_master_key_valid:
|
||||
raise Exception(f"If master key is set, only master key can be used to generate, delete, update or get info for new keys")
|
||||
|
||||
if prisma_client:
|
||||
## check for cache hit (In-Memory Cache)
|
||||
valid_token = user_api_key_cache.get_cache(key=api_key)
|
||||
print(f"valid_token from cache: {valid_token}")
|
||||
if valid_token is None:
|
||||
## check db
|
||||
valid_token = await prisma_client.get_data(token=api_key, expires=datetime.utcnow())
|
||||
user_api_key_cache.set_cache(key=api_key, value=valid_token, ttl=60)
|
||||
elif valid_token is not None:
|
||||
print(f"API Key Cache Hit!")
|
||||
if valid_token:
|
||||
litellm.model_alias_map = valid_token.aliases
|
||||
config = valid_token.config
|
||||
if config != {}:
|
||||
model_list = config.get("model_list", [])
|
||||
llm_model_list = model_list
|
||||
print("\n new llm router model list", llm_model_list)
|
||||
if len(valid_token.models) == 0: # assume an empty model list means all models are allowed to be called
|
||||
api_key = valid_token.token
|
||||
valid_token_dict = _get_pydantic_json_dict(valid_token)
|
||||
valid_token_dict.pop("token", None)
|
||||
return UserAPIKeyAuth(api_key=api_key, **valid_token_dict)
|
||||
else:
|
||||
data = await request.json()
|
||||
model = data.get("model", None)
|
||||
if model in litellm.model_alias_map:
|
||||
model = litellm.model_alias_map[model]
|
||||
if model and model not in valid_token.models:
|
||||
raise Exception(f"Token not allowed to access model")
|
||||
if prisma_client is None: # if both master key + user key submitted, and user key != master key, and no db connected, raise an error
|
||||
raise Exception("No connected db.")
|
||||
|
||||
## check for cache hit (In-Memory Cache)
|
||||
valid_token = user_api_key_cache.get_cache(key=api_key)
|
||||
print(f"valid_token from cache: {valid_token}")
|
||||
if valid_token is None:
|
||||
## check db
|
||||
print(f"api key: {api_key}")
|
||||
valid_token = await prisma_client.get_data(token=api_key, expires=datetime.utcnow())
|
||||
print(f"valid token from prisma: {valid_token}")
|
||||
user_api_key_cache.set_cache(key=api_key, value=valid_token, ttl=60)
|
||||
elif valid_token is not None:
|
||||
print(f"API Key Cache Hit!")
|
||||
if valid_token:
|
||||
litellm.model_alias_map = valid_token.aliases
|
||||
config = valid_token.config
|
||||
if config != {}:
|
||||
model_list = config.get("model_list", [])
|
||||
llm_model_list = model_list
|
||||
print("\n new llm router model list", llm_model_list)
|
||||
if len(valid_token.models) == 0: # assume an empty model list means all models are allowed to be called
|
||||
api_key = valid_token.token
|
||||
valid_token_dict = _get_pydantic_json_dict(valid_token)
|
||||
valid_token_dict.pop("token", None)
|
||||
return UserAPIKeyAuth(api_key=api_key, **valid_token_dict)
|
||||
else:
|
||||
raise Exception(f"Invalid token")
|
||||
try:
|
||||
data = await request.json()
|
||||
except json.JSONDecodeError:
|
||||
data = {} # Provide a default value, such as an empty dictionary
|
||||
model = data.get("model", None)
|
||||
if model in litellm.model_alias_map:
|
||||
model = litellm.model_alias_map[model]
|
||||
if model and model not in valid_token.models:
|
||||
raise Exception(f"Token not allowed to access model")
|
||||
api_key = valid_token.token
|
||||
valid_token_dict = _get_pydantic_json_dict(valid_token)
|
||||
valid_token_dict.pop("token", None)
|
||||
return UserAPIKeyAuth(api_key=api_key, **valid_token_dict)
|
||||
else:
|
||||
raise Exception(f"Invalid token")
|
||||
except Exception as e:
|
||||
print(f"An exception occurred - {traceback.format_exc()}")
|
||||
if isinstance(e, HTTPException):
|
||||
|
@ -310,24 +319,12 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap
|
|||
def prisma_setup(database_url: Optional[str]):
|
||||
global prisma_client, proxy_logging_obj, user_api_key_cache
|
||||
|
||||
proxy_logging_obj._init_litellm_callbacks()
|
||||
if database_url is not None:
|
||||
try:
|
||||
prisma_client = PrismaClient(database_url=database_url, proxy_logging_obj=proxy_logging_obj)
|
||||
except Exception as e:
|
||||
print("Error when initializing prisma, Ensure you run pip install prisma", e)
|
||||
|
||||
def celery_setup(use_queue: bool):
|
||||
global celery_fn, celery_app_conn, async_result
|
||||
if use_queue:
|
||||
from litellm.proxy.queue.celery_worker import start_worker
|
||||
from litellm.proxy.queue.celery_app import celery_app, process_job
|
||||
from celery.result import AsyncResult
|
||||
start_worker(os.getcwd())
|
||||
celery_fn = process_job
|
||||
async_result = AsyncResult
|
||||
celery_app_conn = celery_app
|
||||
|
||||
def load_from_azure_key_vault(use_azure_key_vault: bool = False):
|
||||
if use_azure_key_vault is False:
|
||||
return
|
||||
|
@ -380,30 +377,14 @@ async def track_cost_callback(
|
|||
if "complete_streaming_response" in kwargs:
|
||||
# for tracking streaming cost we pass the "messages" and the output_text to litellm.completion_cost
|
||||
completion_response=kwargs["complete_streaming_response"]
|
||||
input_text = kwargs["messages"]
|
||||
output_text = completion_response["choices"][0]["message"]["content"]
|
||||
response_cost = litellm.completion_cost(
|
||||
model = kwargs["model"],
|
||||
messages = input_text,
|
||||
completion=output_text
|
||||
)
|
||||
response_cost = litellm.completion_cost(completion_response=completion_response)
|
||||
print("streaming response_cost", response_cost)
|
||||
user_api_key = kwargs["litellm_params"]["metadata"].get("user_api_key", None)
|
||||
print(f"user_api_key - {user_api_key}; prisma_client - {prisma_client}")
|
||||
if user_api_key and prisma_client:
|
||||
await update_prisma_database(token=user_api_key, response_cost=response_cost)
|
||||
elif kwargs["stream"] == False: # for non streaming responses
|
||||
input_text = kwargs.get("messages", "")
|
||||
print(f"type of input_text: {type(input_text)}")
|
||||
if isinstance(input_text, list):
|
||||
response_cost = litellm.completion_cost(completion_response=completion_response, messages=input_text)
|
||||
elif isinstance(input_text, str):
|
||||
response_cost = litellm.completion_cost(completion_response=completion_response, prompt=input_text)
|
||||
print(f"received completion response: {completion_response}")
|
||||
|
||||
print(f"regular response_cost: {response_cost}")
|
||||
response_cost = litellm.completion_cost(completion_response=completion_response)
|
||||
user_api_key = kwargs["litellm_params"]["metadata"].get("user_api_key", None)
|
||||
print(f"user_api_key - {user_api_key}; prisma_client - {prisma_client}")
|
||||
if user_api_key and prisma_client:
|
||||
await update_prisma_database(token=user_api_key, response_cost=response_cost)
|
||||
except Exception as e:
|
||||
|
@ -459,7 +440,7 @@ async def _run_background_health_check():
|
|||
await asyncio.sleep(health_check_interval)
|
||||
|
||||
def load_router_config(router: Optional[litellm.Router], config_file_path: str):
|
||||
global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, use_background_health_checks, health_check_interval
|
||||
global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, use_background_health_checks, health_check_interval, use_queue
|
||||
config = {}
|
||||
try:
|
||||
if os.path.exists(config_file_path):
|
||||
|
@ -504,6 +485,18 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
|
|||
cache_port = litellm.get_secret("REDIS_PORT", None)
|
||||
cache_password = litellm.get_secret("REDIS_PASSWORD", None)
|
||||
|
||||
cache_params = {
|
||||
"type": cache_type,
|
||||
"host": cache_host,
|
||||
"port": cache_port,
|
||||
"password": cache_password
|
||||
}
|
||||
|
||||
if "cache_params" in litellm_settings:
|
||||
cache_params_in_config = litellm_settings["cache_params"]
|
||||
# overwrite cache_params with cache_params_in_config
|
||||
cache_params.update(cache_params_in_config)
|
||||
|
||||
# Assuming cache_type, cache_host, cache_port, and cache_password are strings
|
||||
print(f"{blue_color_code}Cache Type:{reset_color_code} {cache_type}")
|
||||
print(f"{blue_color_code}Cache Host:{reset_color_code} {cache_host}")
|
||||
|
@ -513,15 +506,15 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
|
|||
|
||||
## to pass a complete url, or set ssl=True, etc. just set it as `os.environ[REDIS_URL] = <your-redis-url>`, _redis.py checks for REDIS specific environment variables
|
||||
litellm.cache = Cache(
|
||||
type=cache_type,
|
||||
host=cache_host,
|
||||
port=cache_port,
|
||||
password=cache_password
|
||||
**cache_params
|
||||
)
|
||||
print(f"{blue_color_code}Set Cache on LiteLLM Proxy: {litellm.cache.cache}{reset_color_code} {cache_password}")
|
||||
elif key == "callbacks":
|
||||
litellm.callbacks = [get_instance_fn(value=value, config_file_path=config_file_path)]
|
||||
print_verbose(f"{blue_color_code} Initialized Callbacks - {litellm.callbacks} {reset_color_code}")
|
||||
elif key == "post_call_rules":
|
||||
litellm.post_call_rules = [get_instance_fn(value=value, config_file_path=config_file_path)]
|
||||
print(f"litellm.post_call_rules: {litellm.post_call_rules}")
|
||||
elif key == "success_callback":
|
||||
litellm.success_callback = []
|
||||
|
||||
|
@ -533,10 +526,6 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
|
|||
# these are litellm callbacks - "langfuse", "sentry", "wandb"
|
||||
else:
|
||||
litellm.success_callback.append(callback)
|
||||
if callback == "traceloop":
|
||||
from traceloop.sdk import Traceloop
|
||||
print_verbose(f"{blue_color_code} Initializing Traceloop SDK - \nRunning:`Traceloop.init(app_name='Litellm-Server', disable_batch=True)`")
|
||||
Traceloop.init(app_name="Litellm-Server", disable_batch=True)
|
||||
print_verbose(f"{blue_color_code} Initialized Success Callbacks - {litellm.success_callback} {reset_color_code}")
|
||||
elif key == "failure_callback":
|
||||
litellm.failure_callback = []
|
||||
|
@ -550,6 +539,10 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
|
|||
else:
|
||||
litellm.failure_callback.append(callback)
|
||||
print_verbose(f"{blue_color_code} Initialized Success Callbacks - {litellm.failure_callback} {reset_color_code}")
|
||||
elif key == "cache_params":
|
||||
# this is set in the cache branch
|
||||
# see usage here: https://docs.litellm.ai/docs/proxy/caching
|
||||
pass
|
||||
else:
|
||||
setattr(litellm, key, value)
|
||||
|
||||
|
@ -572,7 +565,6 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
|
|||
cost_tracking()
|
||||
### START REDIS QUEUE ###
|
||||
use_queue = general_settings.get("use_queue", False)
|
||||
celery_setup(use_queue=use_queue)
|
||||
### MASTER KEY ###
|
||||
master_key = general_settings.get("master_key", None)
|
||||
if master_key and master_key.startswith("os.environ/"):
|
||||
|
@ -683,6 +675,8 @@ async def generate_key_helper_fn(duration: Optional[str], models: list, aliases:
|
|||
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR)
|
||||
return {"token": token, "expires": new_verification_token.expires, "user_id": user_id}
|
||||
|
||||
|
||||
|
||||
async def delete_verification_token(tokens: List):
|
||||
global prisma_client
|
||||
try:
|
||||
|
@ -761,8 +755,6 @@ def initialize(
|
|||
if max_budget: # litellm-specific param
|
||||
litellm.max_budget = max_budget
|
||||
dynamic_config["general"]["max_budget"] = max_budget
|
||||
if use_queue:
|
||||
celery_setup(use_queue=use_queue)
|
||||
if experimental:
|
||||
pass
|
||||
user_telemetry = telemetry
|
||||
|
@ -798,48 +790,12 @@ def data_generator(response):
|
|||
async def async_data_generator(response, user_api_key_dict):
|
||||
print_verbose("inside generator")
|
||||
async for chunk in response:
|
||||
# try:
|
||||
# await proxy_logging_obj.pre_call_hook(user_api_key_dict=user_api_key_dict, data=None, call_type="completion")
|
||||
# except Exception as e:
|
||||
# print(f"An exception occurred - {str(e)}")
|
||||
|
||||
print_verbose(f"returned chunk: {chunk}")
|
||||
try:
|
||||
yield f"data: {json.dumps(chunk.dict())}\n\n"
|
||||
except:
|
||||
yield f"data: {json.dumps(chunk)}\n\n"
|
||||
|
||||
def litellm_completion(*args, **kwargs):
|
||||
global user_temperature, user_request_timeout, user_max_tokens, user_api_base
|
||||
call_type = kwargs.pop("call_type")
|
||||
# override with user settings, these are params passed via cli
|
||||
if user_temperature:
|
||||
kwargs["temperature"] = user_temperature
|
||||
if user_request_timeout:
|
||||
kwargs["request_timeout"] = user_request_timeout
|
||||
if user_max_tokens:
|
||||
kwargs["max_tokens"] = user_max_tokens
|
||||
if user_api_base:
|
||||
kwargs["api_base"] = user_api_base
|
||||
## ROUTE TO CORRECT ENDPOINT ##
|
||||
router_model_names = [m["model_name"] for m in llm_model_list] if llm_model_list is not None else []
|
||||
try:
|
||||
if llm_router is not None and kwargs["model"] in router_model_names: # model in router model list
|
||||
if call_type == "chat_completion":
|
||||
response = llm_router.completion(*args, **kwargs)
|
||||
elif call_type == "text_completion":
|
||||
response = llm_router.text_completion(*args, **kwargs)
|
||||
else:
|
||||
if call_type == "chat_completion":
|
||||
response = litellm.completion(*args, **kwargs)
|
||||
elif call_type == "text_completion":
|
||||
response = litellm.text_completion(*args, **kwargs)
|
||||
except Exception as e:
|
||||
raise e
|
||||
if 'stream' in kwargs and kwargs['stream'] == True: # use generate_responses to stream responses
|
||||
return StreamingResponse(data_generator(response), media_type='text/event-stream')
|
||||
return response
|
||||
|
||||
def get_litellm_model_info(model: dict = {}):
|
||||
model_info = model.get("model_info", {})
|
||||
model_to_lookup = model.get("litellm_params", {}).get("model", None)
|
||||
|
@ -870,6 +826,8 @@ async def startup_event():
|
|||
initialize(**worker_config)
|
||||
|
||||
|
||||
proxy_logging_obj._init_litellm_callbacks() # INITIALIZE LITELLM CALLBACKS ON SERVER STARTUP <- do this to catch any logging errors on startup, not when calls are being made
|
||||
|
||||
if use_background_health_checks:
|
||||
asyncio.create_task(_run_background_health_check()) # start the background health check coroutine.
|
||||
|
||||
|
@ -881,16 +839,6 @@ async def startup_event():
|
|||
# add master key to db
|
||||
await generate_key_helper_fn(duration=None, models=[], aliases={}, config={}, spend=0, token=master_key)
|
||||
|
||||
@router.on_event("shutdown")
|
||||
async def shutdown_event():
|
||||
global prisma_client, master_key, user_custom_auth
|
||||
if prisma_client:
|
||||
print("Disconnecting from Prisma")
|
||||
await prisma_client.disconnect()
|
||||
|
||||
## RESET CUSTOM VARIABLES ##
|
||||
master_key = None
|
||||
user_custom_auth = None
|
||||
|
||||
#### API ENDPOINTS ####
|
||||
@router.get("/v1/models", dependencies=[Depends(user_api_key_auth)])
|
||||
|
@ -929,7 +877,8 @@ def model_list():
|
|||
@router.post("/v1/completions", dependencies=[Depends(user_api_key_auth)])
|
||||
@router.post("/completions", dependencies=[Depends(user_api_key_auth)])
|
||||
@router.post("/engines/{model:path}/completions", dependencies=[Depends(user_api_key_auth)])
|
||||
async def completion(request: Request, model: Optional[str] = None, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth)):
|
||||
async def completion(request: Request, model: Optional[str] = None, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), background_tasks: BackgroundTasks = BackgroundTasks()):
|
||||
global user_temperature, user_request_timeout, user_max_tokens, user_api_base
|
||||
try:
|
||||
body = await request.body()
|
||||
body_str = body.decode()
|
||||
|
@ -938,7 +887,7 @@ async def completion(request: Request, model: Optional[str] = None, user_api_key
|
|||
except:
|
||||
data = json.loads(body_str)
|
||||
|
||||
data["user"] = user_api_key_dict.user_id
|
||||
data["user"] = data.get("user", user_api_key_dict.user_id)
|
||||
data["model"] = (
|
||||
general_settings.get("completion_model", None) # server default
|
||||
or user_model # model name passed via cli args
|
||||
|
@ -947,17 +896,44 @@ async def completion(request: Request, model: Optional[str] = None, user_api_key
|
|||
)
|
||||
if user_model:
|
||||
data["model"] = user_model
|
||||
data["call_type"] = "text_completion"
|
||||
if "metadata" in data:
|
||||
data["metadata"]["user_api_key"] = user_api_key_dict.api_key
|
||||
else:
|
||||
data["metadata"] = {"user_api_key": user_api_key_dict.api_key}
|
||||
|
||||
return litellm_completion(
|
||||
**data
|
||||
)
|
||||
# override with user settings, these are params passed via cli
|
||||
if user_temperature:
|
||||
data["temperature"] = user_temperature
|
||||
if user_request_timeout:
|
||||
data["request_timeout"] = user_request_timeout
|
||||
if user_max_tokens:
|
||||
data["max_tokens"] = user_max_tokens
|
||||
if user_api_base:
|
||||
data["api_base"] = user_api_base
|
||||
|
||||
### CALL HOOKS ### - modify incoming data before calling the model
|
||||
data = await proxy_logging_obj.pre_call_hook(user_api_key_dict=user_api_key_dict, data=data, call_type="completion")
|
||||
|
||||
### ROUTE THE REQUEST ###
|
||||
router_model_names = [m["model_name"] for m in llm_model_list] if llm_model_list is not None else []
|
||||
if llm_router is not None and data["model"] in router_model_names: # model in router model list
|
||||
response = await llm_router.atext_completion(**data)
|
||||
elif llm_router is not None and data["model"] in llm_router.deployment_names: # model in router deployments, calling a specific deployment on the router
|
||||
response = await llm_router.atext_completion(**data, specific_deployment = True)
|
||||
elif llm_router is not None and llm_router.model_group_alias is not None and data["model"] in llm_router.model_group_alias: # model set in model_group_alias
|
||||
response = await llm_router.atext_completion(**data)
|
||||
else: # router is not set
|
||||
response = await litellm.atext_completion(**data)
|
||||
|
||||
print(f"final response: {response}")
|
||||
if 'stream' in data and data['stream'] == True: # use generate_responses to stream responses
|
||||
return StreamingResponse(async_data_generator(user_api_key_dict=user_api_key_dict, response=response), media_type='text/event-stream')
|
||||
|
||||
background_tasks.add_task(log_input_output, request, response) # background task for logging to OTEL
|
||||
return response
|
||||
except Exception as e:
|
||||
print(f"\033[1;31mAn error occurred: {e}\n\n Debug this by setting `--debug`, e.g. `litellm --model gpt-3.5-turbo --debug`")
|
||||
traceback.print_exc()
|
||||
error_traceback = traceback.format_exc()
|
||||
error_msg = f"{str(e)}\n\n{error_traceback}"
|
||||
try:
|
||||
|
@ -995,7 +971,7 @@ async def chat_completion(request: Request, model: Optional[str] = None, user_ap
|
|||
)
|
||||
|
||||
# users can pass in 'user' param to /chat/completions. Don't override it
|
||||
if data.get("user", None) is None:
|
||||
if data.get("user", None) is None and user_api_key_dict.user_id is not None:
|
||||
# if users are using user_api_key_auth, set `user` in `data`
|
||||
data["user"] = user_api_key_dict.user_id
|
||||
|
||||
|
@ -1027,7 +1003,7 @@ async def chat_completion(request: Request, model: Optional[str] = None, user_ap
|
|||
response = await llm_router.acompletion(**data)
|
||||
elif llm_router is not None and data["model"] in llm_router.deployment_names: # model in router deployments, calling a specific deployment on the router
|
||||
response = await llm_router.acompletion(**data, specific_deployment = True)
|
||||
elif llm_router is not None and litellm.model_group_alias_map is not None and data["model"] in litellm.model_group_alias_map: # model set in model_group_alias_map
|
||||
elif llm_router is not None and llm_router.model_group_alias is not None and data["model"] in llm_router.model_group_alias: # model set in model_group_alias
|
||||
response = await llm_router.acompletion(**data)
|
||||
else: # router is not set
|
||||
response = await litellm.acompletion(**data)
|
||||
|
@ -1088,7 +1064,9 @@ async def embeddings(request: Request, user_api_key_dict: UserAPIKeyAuth = Depen
|
|||
"body": copy.copy(data) # use copy instead of deepcopy
|
||||
}
|
||||
|
||||
data["user"] = user_api_key_dict.user_id
|
||||
if data.get("user", None) is None and user_api_key_dict.user_id is not None:
|
||||
data["user"] = user_api_key_dict.user_id
|
||||
|
||||
data["model"] = (
|
||||
general_settings.get("embedding_model", None) # server default
|
||||
or user_model # model name passed via cli args
|
||||
|
@ -1098,10 +1076,11 @@ async def embeddings(request: Request, user_api_key_dict: UserAPIKeyAuth = Depen
|
|||
data["model"] = user_model
|
||||
if "metadata" in data:
|
||||
data["metadata"]["user_api_key"] = user_api_key_dict.api_key
|
||||
data["metadata"]["headers"] = dict(request.headers)
|
||||
else:
|
||||
data["metadata"] = {"user_api_key": user_api_key_dict.api_key}
|
||||
data["metadata"]["headers"] = dict(request.headers)
|
||||
router_model_names = [m["model_name"] for m in llm_model_list] if llm_model_list is not None else []
|
||||
print(f"received data: {data['input']}")
|
||||
if "input" in data and isinstance(data['input'], list) and isinstance(data['input'][0], list) and isinstance(data['input'][0][0], int): # check if array of tokens passed in
|
||||
# check if non-openai/azure model called - e.g. for langchain integration
|
||||
if llm_model_list is not None and data["model"] in router_model_names:
|
||||
|
@ -1119,12 +1098,13 @@ async def embeddings(request: Request, user_api_key_dict: UserAPIKeyAuth = Depen
|
|||
|
||||
### CALL HOOKS ### - modify incoming data / reject request before calling the model
|
||||
data = await proxy_logging_obj.pre_call_hook(user_api_key_dict=user_api_key_dict, data=data, call_type="embeddings")
|
||||
|
||||
## ROUTE TO CORRECT ENDPOINT ##
|
||||
if llm_router is not None and data["model"] in router_model_names: # model in router model list
|
||||
response = await llm_router.aembedding(**data)
|
||||
elif llm_router is not None and data["model"] in llm_router.deployment_names: # model in router deployments, calling a specific deployment on the router
|
||||
response = await llm_router.aembedding(**data, specific_deployment = True)
|
||||
elif llm_router is not None and llm_router.model_group_alias is not None and data["model"] in llm_router.model_group_alias: # model set in model_group_alias
|
||||
response = await llm_router.aembedding(**data) # ensure this goes the llm_router, router will do the correct alias mapping
|
||||
else:
|
||||
response = await litellm.aembedding(**data)
|
||||
background_tasks.add_task(log_input_output, request, response) # background task for logging to OTEL
|
||||
|
@ -1133,7 +1113,19 @@ async def embeddings(request: Request, user_api_key_dict: UserAPIKeyAuth = Depen
|
|||
except Exception as e:
|
||||
await proxy_logging_obj.post_call_failure_hook(user_api_key_dict=user_api_key_dict, original_exception=e)
|
||||
traceback.print_exc()
|
||||
raise e
|
||||
if isinstance(e, HTTPException):
|
||||
raise e
|
||||
else:
|
||||
error_traceback = traceback.format_exc()
|
||||
error_msg = f"{str(e)}\n\n{error_traceback}"
|
||||
try:
|
||||
status = e.status_code # type: ignore
|
||||
except:
|
||||
status = 500
|
||||
raise HTTPException(
|
||||
status_code=status,
|
||||
detail=error_msg
|
||||
)
|
||||
|
||||
#### KEY MANAGEMENT ####
|
||||
|
||||
|
@ -1162,6 +1154,30 @@ async def generate_key_fn(request: Request, data: GenerateKeyRequest, Authorizat
|
|||
response = await generate_key_helper_fn(**data_json)
|
||||
return GenerateKeyResponse(key=response["token"], expires=response["expires"], user_id=response["user_id"])
|
||||
|
||||
@router.post("/key/update", tags=["key management"], dependencies=[Depends(user_api_key_auth)])
|
||||
async def update_key_fn(request: Request, data: UpdateKeyRequest):
|
||||
"""
|
||||
Update an existing key
|
||||
"""
|
||||
global prisma_client
|
||||
try:
|
||||
data_json: dict = data.json()
|
||||
key = data_json.pop("key")
|
||||
# get the row from db
|
||||
if prisma_client is None:
|
||||
raise Exception("Not connected to DB!")
|
||||
|
||||
non_default_values = {k: v for k, v in data_json.items() if v is not None}
|
||||
print(f"non_default_values: {non_default_values}")
|
||||
response = await prisma_client.update_data(token=key, data={**non_default_values, "token": key})
|
||||
return {"key": key, **non_default_values}
|
||||
# update based on remaining passed in values
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail={"error": str(e)},
|
||||
)
|
||||
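Not part of the commit - a hypothetical client call against the new endpoint. URL, keys, and values are placeholders; only fields accepted by `UpdateKeyRequest` are sent, and omitted fields are left untouched server-side because of the `non_default_values` filter above. The `/key/*` routes require the master key when one is configured.

```python
import requests

# Placeholder values - substitute your proxy URL, master key, and an existing generated key.
resp = requests.post(
    "http://0.0.0.0:8000/key/update",
    headers={"Authorization": "Bearer sk-1234"},  # master key, as in the example config above
    json={
        "key": "sk-existing-generated-key",
        "models": ["gpt-3.5-turbo"],        # restrict the key to these models
        "max_parallel_requests": 10,        # omitted fields keep their current values
    },
)
print(resp.json())  # echoes the key plus the fields that were updated
```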
|
||||
@router.post("/key/delete", tags=["key management"], dependencies=[Depends(user_api_key_auth)])
|
||||
async def delete_key_fn(request: Request, data: DeleteKeyRequest):
|
||||
try:
|
||||
|
@ -1207,10 +1223,12 @@ async def add_new_model(model_params: ModelParams):
|
|||
|
||||
print_verbose(f"Loaded config: {config}")
|
||||
# Add the new model to the config
|
||||
model_info = model_params.model_info.json()
|
||||
model_info = {k: v for k, v in model_info.items() if v is not None}
|
||||
config['model_list'].append({
|
||||
'model_name': model_params.model_name,
|
||||
'litellm_params': model_params.litellm_params,
|
||||
'model_info': model_params.model_info
|
||||
'model_info': model_info
|
||||
})
|
||||
|
||||
# Save the updated config
|
||||
|
@@ -1228,7 +1246,7 @@ async def add_new_model(model_params: ModelParams):
raise HTTPException(status_code=500, detail=f"Internal Server Error: {str(e)}")

#### [BETA] - This is a beta endpoint, format might change based on user feedback https://github.com/BerriAI/litellm/issues/933. If you need a stable endpoint use /model/info
@router.get("/v1/model/info", description="Provides more info about each model in /models, including config.yaml descriptions (except api key and api base)", tags=["model management"], dependencies=[Depends(user_api_key_auth)])
@router.get("/model/info", description="Provides more info about each model in /models, including config.yaml descriptions (except api key and api base)", tags=["model management"], dependencies=[Depends(user_api_key_auth)])
async def model_info_v1(request: Request):
global llm_model_list, general_settings, user_config_file_path
# Load existing config

@@ -1256,7 +1274,7 @@ async def model_info_v1(request: Request):


#### [BETA] - This is a beta endpoint, format might change based on user feedback. - https://github.com/BerriAI/litellm/issues/933
@router.get("/model/info", description="Provides more info about each model in /models, including config.yaml descriptions (except api key and api base)", tags=["model management"], dependencies=[Depends(user_api_key_auth)])
@router.get("/v1/model/info", description="Provides more info about each model in /models, including config.yaml descriptions (except api key and api base)", tags=["model management"], dependencies=[Depends(user_api_key_auth)])
async def model_info(request: Request):
global llm_model_list, general_settings, user_config_file_path
# Load existing config

@@ -1341,47 +1359,108 @@ async def delete_model(model_info: ModelInfoDelete):
raise HTTPException(status_code=500, detail=f"Internal Server Error: {str(e)}")

#### EXPERIMENTAL QUEUING ####
@router.post("/queue/request", dependencies=[Depends(user_api_key_auth)])
async def async_queue_request(request: Request):
global celery_fn, llm_model_list
if celery_fn is not None:
body = await request.body()
body_str = body.decode()
try:
data = ast.literal_eval(body_str)
except:
data = json.loads(body_str)
async def _litellm_chat_completions_worker(data, user_api_key_dict):
"""
worker to make litellm completions calls
"""
while True:
try:
### CALL HOOKS ### - modify incoming data before calling the model
data = await proxy_logging_obj.pre_call_hook(user_api_key_dict=user_api_key_dict, data=data, call_type="completion")

print(f"_litellm_chat_completions_worker started")
### ROUTE THE REQUEST ###
router_model_names = [m["model_name"] for m in llm_model_list] if llm_model_list is not None else []
if llm_router is not None and data["model"] in router_model_names: # model in router model list
response = await llm_router.acompletion(**data)
elif llm_router is not None and data["model"] in llm_router.deployment_names: # model in router deployments, calling a specific deployment on the router
response = await llm_router.acompletion(**data, specific_deployment = True)
elif llm_router is not None and llm_router.model_group_alias is not None and data["model"] in llm_router.model_group_alias: # model set in model_group_alias
response = await llm_router.acompletion(**data)
else: # router is not set
response = await litellm.acompletion(**data)

print(f"final response: {response}")
return response
except HTTPException as e:
print(f"EXCEPTION RAISED IN _litellm_chat_completions_worker - {e.status_code}; {e.detail}")
if e.status_code == 429 and "Max parallel request limit reached" in e.detail:
print(f"Max parallel request limit reached!")
timeout = litellm._calculate_retry_after(remaining_retries=3, max_retries=3, min_timeout=1)
await asyncio.sleep(timeout)
else:
raise e


@router.post("/queue/chat/completions", tags=["experimental"], dependencies=[Depends(user_api_key_auth)])
async def async_queue_request(request: Request, model: Optional[str] = None, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), background_tasks: BackgroundTasks = BackgroundTasks()):
global general_settings, user_debug, proxy_logging_obj
"""
v2 attempt at a background worker to handle queuing.

Just supports /chat/completion calls currently.

Now using a FastAPI background task + /chat/completions compatible endpoint
"""
try:
data = {}
data = await request.json() # type: ignore

# Include original request and headers in the data
data["proxy_server_request"] = {
"url": str(request.url),
"method": request.method,
"headers": dict(request.headers),
"body": copy.copy(data) # use copy instead of deepcopy
}

print_verbose(f"receiving data: {data}")
data["model"] = (
general_settings.get("completion_model", None) # server default
or user_model # model name passed via cli args
or model # for azure deployments
or data["model"] # default passed in http request
)
data["llm_model_list"] = llm_model_list
print(f"data: {data}")
job = celery_fn.apply_async(kwargs=data)
return {"id": job.id, "url": f"/queue/response/{job.id}", "eta": 5, "status": "queued"}
else:

# users can pass in 'user' param to /chat/completions. Don't override it
if data.get("user", None) is None and user_api_key_dict.user_id is not None:
# if users are using user_api_key_auth, set `user` in `data`
data["user"] = user_api_key_dict.user_id

if "metadata" in data:
print(f'received metadata: {data["metadata"]}')
data["metadata"]["user_api_key"] = user_api_key_dict.api_key
data["metadata"]["headers"] = dict(request.headers)
else:
data["metadata"] = {"user_api_key": user_api_key_dict.api_key}
data["metadata"]["headers"] = dict(request.headers)

global user_temperature, user_request_timeout, user_max_tokens, user_api_base
# override with user settings, these are params passed via cli
if user_temperature:
data["temperature"] = user_temperature
if user_request_timeout:
data["request_timeout"] = user_request_timeout
if user_max_tokens:
data["max_tokens"] = user_max_tokens
if user_api_base:
data["api_base"] = user_api_base

response = await asyncio.wait_for(_litellm_chat_completions_worker(data=data, user_api_key_dict=user_api_key_dict), timeout=litellm.request_timeout)

if 'stream' in data and data['stream'] == True: # use generate_responses to stream responses
return StreamingResponse(async_data_generator(user_api_key_dict=user_api_key_dict, response=response), media_type='text/event-stream')

background_tasks.add_task(log_input_output, request, response) # background task for logging to OTEL
return response
except Exception as e:
await proxy_logging_obj.post_call_failure_hook(user_api_key_dict=user_api_key_dict, original_exception=e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={"error": "Queue not initialized"},
detail={"error": str(e)},
)
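
A rough usage sketch for the queued endpoint above: it accepts a normal /chat/completions payload at /queue/chat/completions. The base URL and API key here are assumptions for illustration:

# hypothetical client call against the FastAPI background-task queue endpoint
import requests

resp = requests.post(
    "http://0.0.0.0:8000/queue/chat/completions",  # assumed proxy address
    headers={"Authorization": "Bearer sk-1234"},   # any key accepted by user_api_key_auth - placeholder
    json={
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": "write a one line poem"}],
    },
)
print(resp.json())  # the worker's completion response; 429s are retried inside the worker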

@router.get("/queue/response/{task_id}", dependencies=[Depends(user_api_key_auth)])
async def async_queue_response(request: Request, task_id: str):
global celery_app_conn, async_result
try:
if celery_app_conn is not None and async_result is not None:
job = async_result(task_id, app=celery_app_conn)
if job.ready():
return {"status": "finished", "result": job.result}
else:
return {'status': 'queued'}
else:
raise Exception()
except Exception as e:
return {"status": "finished", "result": str(e)}




@router.get("/ollama_logs", dependencies=[Depends(user_api_key_auth)])
async def retrieve_server_log(request: Request):
filepath = os.path.expanduser("~/.ollama/logs/server.log")

@@ -1411,8 +1490,18 @@ async def config_yaml_endpoint(config_info: ConfigYAML):
return {"hello": "world"}


@router.get("/test")
@router.get("/test", tags=["health"])
async def test_endpoint(request: Request):
"""
A test endpoint that pings the proxy server to check if it's healthy.

Parameters:
request (Request): The incoming request.

Returns:
dict: A dictionary containing the route of the request URL.
"""
# ping the proxy server to check if its healthy
return {"route": request.url.path}
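
For completeness, a tiny sketch of hitting the re-tagged /test route; the host and port are an assumption:

# hypothetical health ping against the proxy
import requests

print(requests.get("http://0.0.0.0:8000/test").json())  # -> {"route": "/test"}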

@router.get("/health", tags=["health"], dependencies=[Depends(user_api_key_auth)])

@@ -1470,4 +1559,27 @@ async def get_routes():
return {"routes": routes}


@router.on_event("shutdown")
async def shutdown_event():
global prisma_client, master_key, user_custom_auth
if prisma_client:
print("Disconnecting from Prisma")
await prisma_client.disconnect()

## RESET CUSTOM VARIABLES ##
cleanup_router_config_variables()

def cleanup_router_config_variables():
global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, use_background_health_checks, health_check_interval

# Set all variables to None
master_key = None
user_config_file_path = None
otel_logging = None
user_custom_auth = None
user_custom_auth_path = None
use_background_health_checks = None
health_check_interval = None


app.include_router(router)

@@ -1,23 +0,0 @@
import openai
client = openai.OpenAI(
api_key="anything",
# base_url="http://0.0.0.0:8000",
)

try:
# request sent to model set on litellm proxy, `litellm --model`
response = client.chat.completions.create(model="gpt-3.5-turbo", messages = [
{
"role": "user",
"content": "this is a test request, write a short poem"
},
])

print(response)
# except openai.APITimeoutError:
# print("Got openai Timeout Exception. Good job. The proxy mapped to OpenAI exceptions")
except Exception as e:
print("\n the proxy did not map to OpenAI exception. Instead got", e)
print(e.type) # type: ignore
print(e.message) # type: ignore
print(e.code) # type: ignore
@@ -1,13 +1,13 @@
from typing import Optional, List, Any, Literal
import os, subprocess, hashlib, importlib, asyncio
import os, subprocess, hashlib, importlib, asyncio, copy
import litellm, backoff
from litellm.proxy._types import UserAPIKeyAuth
from litellm.caching import DualCache
from litellm.proxy.hooks.parallel_request_limiter import MaxParallelRequestsHandler

from litellm.integrations.custom_logger import CustomLogger
def print_verbose(print_statement):
if litellm.set_verbose:
print(print_statement) # noqa
print(f"LiteLLM Proxy: {print_statement}") # noqa
### LOGGING ###
class ProxyLogging:
"""

@@ -26,7 +26,7 @@ class ProxyLogging:
pass

def _init_litellm_callbacks(self):

print_verbose(f"INITIALIZING LITELLM CALLBACKS!")
litellm.callbacks.append(self.max_parallel_request_limiter)
for callback in litellm.callbacks:
if callback not in litellm.input_callback:

@ -64,18 +64,14 @@ class ProxyLogging:
|
|||
1. /chat/completions
|
||||
2. /embeddings
|
||||
"""
|
||||
try:
|
||||
self.call_details["data"] = data
|
||||
self.call_details["call_type"] = call_type
|
||||
try:
|
||||
for callback in litellm.callbacks:
|
||||
if isinstance(callback, CustomLogger) and 'async_pre_call_hook' in vars(callback.__class__):
|
||||
response = await callback.async_pre_call_hook(user_api_key_dict=user_api_key_dict, cache=self.call_details["user_api_key_cache"], data=data, call_type=call_type)
|
||||
if response is not None:
|
||||
data = response
|
||||
|
||||
## check if max parallel requests set
|
||||
if user_api_key_dict.max_parallel_requests is not None:
|
||||
## if set, check if request allowed
|
||||
await self.max_parallel_request_limiter.max_parallel_request_allow_request(
|
||||
max_parallel_requests=user_api_key_dict.max_parallel_requests,
|
||||
api_key=user_api_key_dict.api_key,
|
||||
user_api_key_cache=self.call_details["user_api_key_cache"])
|
||||
|
||||
print_verbose(f'final data being sent to {call_type} call: {data}')
|
||||
return data
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
@ -103,17 +99,13 @@ class ProxyLogging:
|
|||
1. /chat/completions
|
||||
2. /embeddings
|
||||
"""
|
||||
# check if max parallel requests set
|
||||
if user_api_key_dict.max_parallel_requests is not None:
|
||||
## decrement call count if call failed
|
||||
if (hasattr(original_exception, "status_code")
|
||||
and original_exception.status_code == 429
|
||||
and "Max parallel request limit reached" in str(original_exception)):
|
||||
pass # ignore failed calls due to max limit being reached
|
||||
else:
|
||||
await self.max_parallel_request_limiter.async_log_failure_call(
|
||||
api_key=user_api_key_dict.api_key,
|
||||
user_api_key_cache=self.call_details["user_api_key_cache"])
|
||||
|
||||
for callback in litellm.callbacks:
|
||||
try:
|
||||
if isinstance(callback, CustomLogger):
|
||||
await callback.async_post_call_failure_hook(user_api_key_dict=user_api_key_dict, original_exception=original_exception)
|
||||
except Exception as e:
|
||||
raise e
|
||||
return
|
||||
|
||||
|
||||
|
@@ -165,19 +157,20 @@ class PrismaClient:
async def get_data(self, token: str, expires: Optional[Any]=None):
try:
# check if plain text or hash
hashed_token = token
if token.startswith("sk-"):
token = self.hash_token(token=token)
hashed_token = self.hash_token(token=token)
if expires:
response = await self.db.litellm_verificationtoken.find_first(
where={
"token": token,
"token": hashed_token,
"expires": {"gte": expires} # Check if the token is not expired
}
)
else:
response = await self.db.litellm_verificationtoken.find_unique(
where={
"token": token
"token": hashed_token
}
)
return response
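
A standalone sketch (not taken from this diff) of the lookup behaviour above: plaintext "sk-..." keys are hashed before querying, so the table only ever stores hashed tokens. hash_token below is assumed to be a sha256 hex digest; litellm's real helper may differ:

import hashlib

def hash_token(token: str) -> str:
    # assumed implementation - the diff only shows that self.hash_token() exists
    return hashlib.sha256(token.encode()).hexdigest()

def get_data_sketch(table: dict, token: str):
    hashed_token = token
    if token.startswith("sk-"):
        hashed_token = hash_token(token)
    return table.get(hashed_token)  # stand-in for litellm_verificationtoken.find_unique

table = {hash_token("sk-my-key"): {"models": ["gpt-3.5-turbo"]}}
assert get_data_sketch(table, "sk-my-key") == {"models": ["gpt-3.5-turbo"]}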
|
||||
|
@ -200,18 +193,18 @@ class PrismaClient:
|
|||
try:
|
||||
token = data["token"]
|
||||
hashed_token = self.hash_token(token=token)
|
||||
data["token"] = hashed_token
|
||||
db_data = copy.deepcopy(data)
|
||||
db_data["token"] = hashed_token
|
||||
|
||||
new_verification_token = await self.db.litellm_verificationtoken.upsert( # type: ignore
|
||||
where={
|
||||
'token': hashed_token,
|
||||
},
|
||||
data={
|
||||
"create": {**data}, #type: ignore
|
||||
"create": {**db_data}, #type: ignore
|
||||
"update": {} # don't do anything if it already exists
|
||||
}
|
||||
)
|
||||
|
||||
return new_verification_token
|
||||
except Exception as e:
|
||||
asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e))
|
||||
|
@ -235,15 +228,16 @@ class PrismaClient:
|
|||
if token.startswith("sk-"):
|
||||
token = self.hash_token(token=token)
|
||||
|
||||
data["token"] = token
|
||||
db_data = copy.deepcopy(data)
|
||||
db_data["token"] = token
|
||||
response = await self.db.litellm_verificationtoken.update(
|
||||
where={
|
||||
"token": token
|
||||
},
|
||||
data={**data} # type: ignore
|
||||
data={**db_data} # type: ignore
|
||||
)
|
||||
print_verbose("\033[91m" + f"DB write succeeded {response}" + "\033[0m")
|
||||
return {"token": token, "data": data}
|
||||
return {"token": token, "data": db_data}
|
||||
except Exception as e:
|
||||
asyncio.create_task(self.proxy_logging_obj.failure_handler(original_exception=e))
|
||||
print_verbose("\033[91m" + f"DB write failed: {e}" + "\033[0m")
|
||||
|
|
|
@ -7,6 +7,7 @@
|
|||
#
|
||||
# Thank you ! We ❤️ you! - Krrish & Ishaan
|
||||
|
||||
import copy
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional, Union, Literal, Any
|
||||
import random, threading, time, traceback, uuid
|
||||
|
@ -17,6 +18,7 @@ import inspect, concurrent
|
|||
from openai import AsyncOpenAI
|
||||
from collections import defaultdict
|
||||
from litellm.router_strategy.least_busy import LeastBusyLoggingHandler
|
||||
import copy
|
||||
class Router:
|
||||
"""
|
||||
Example usage:
|
||||
|
@ -68,6 +70,7 @@ class Router:
|
|||
redis_password: Optional[str] = None,
|
||||
cache_responses: Optional[bool] = False,
|
||||
cache_kwargs: dict = {}, # additional kwargs to pass to RedisCache (see caching.py)
|
||||
caching_groups: Optional[List[tuple]] = None, # if you want to cache across model groups
|
||||
## RELIABILITY ##
|
||||
num_retries: int = 0,
|
||||
timeout: Optional[float] = None,
|
||||
|
@ -76,11 +79,13 @@ class Router:
|
|||
fallbacks: List = [],
|
||||
allowed_fails: Optional[int] = None,
|
||||
context_window_fallbacks: List = [],
|
||||
model_group_alias: Optional[dict] = {},
|
||||
routing_strategy: Literal["simple-shuffle", "least-busy", "usage-based-routing", "latency-based-routing"] = "simple-shuffle") -> None:
|
||||
|
||||
self.set_verbose = set_verbose
|
||||
self.deployment_names: List = [] # names of models under litellm_params. ex. azure/chatgpt-v-2
|
||||
if model_list:
|
||||
model_list = copy.deepcopy(model_list)
|
||||
self.set_model_list(model_list)
|
||||
self.healthy_deployments: List = self.model_list
|
||||
self.deployment_latency_map = {}
|
||||
|
@ -99,6 +104,7 @@ class Router:
|
|||
self.fail_calls: defaultdict = defaultdict(int) # dict to store fail_calls made to each model
|
||||
self.success_calls: defaultdict = defaultdict(int) # dict to store success_calls made to each model
|
||||
self.previous_models: List = [] # list to store failed calls (passed in as metadata to next call)
|
||||
self.model_group_alias: dict = model_group_alias or {} # dict to store aliases for router, ex. {"gpt-4": "gpt-3.5-turbo"}, all requests with gpt-4 -> get routed to gpt-3.5-turbo group
|
||||
|
||||
# make Router.chat.completions.create compatible for openai.chat.completions.create
|
||||
self.chat = litellm.Chat(params=default_litellm_params)
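
A short sketch of the new model_group_alias parameter wired up in this __init__ hunk; the deployment details are placeholders and the OPENAI_API_KEY env var is assumed to be set:

# route requests for "gpt-4" to the "gpt-3.5-turbo" model group
import os
from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {"model": "gpt-3.5-turbo", "api_key": os.environ["OPENAI_API_KEY"]},
        }
    ],
    model_group_alias={"gpt-4": "gpt-3.5-turbo"},  # requests for gpt-4 get routed to the gpt-3.5-turbo group
)
response = router.completion(model="gpt-4", messages=[{"role": "user", "content": "hi"}])
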
|
||||
|
@ -107,9 +113,10 @@ class Router:
|
|||
self.default_litellm_params = default_litellm_params
|
||||
self.default_litellm_params.setdefault("timeout", timeout)
|
||||
self.default_litellm_params.setdefault("max_retries", 0)
|
||||
self.default_litellm_params.setdefault("metadata", {}).update({"caching_groups": caching_groups})
|
||||
|
||||
### CACHING ###
|
||||
cache_type = "local" # default to an in-memory cache
|
||||
cache_type: Literal["local", "redis"] = "local" # default to an in-memory cache
|
||||
redis_cache = None
|
||||
cache_config = {}
|
||||
if redis_url is not None or (redis_host is not None and redis_port is not None and redis_password is not None):
|
||||
|
@ -133,7 +140,7 @@ class Router:
|
|||
if cache_responses:
|
||||
if litellm.cache is None:
|
||||
# the cache can be initialized on the proxy server. We should not overwrite it
|
||||
litellm.cache = litellm.Cache(type=cache_type, **cache_config)
|
||||
litellm.cache = litellm.Cache(type=cache_type, **cache_config) # type: ignore
|
||||
self.cache_responses = cache_responses
|
||||
self.cache = DualCache(redis_cache=redis_cache, in_memory_cache=InMemoryCache()) # use a dual cache (Redis+In-Memory) for tracking cooldowns, usage, etc.
|
||||
### ROUTING SETUP ###
|
||||
|
@ -198,19 +205,10 @@ class Router:
|
|||
data = deployment["litellm_params"].copy()
|
||||
kwargs["model_info"] = deployment.get("model_info", {})
|
||||
for k, v in self.default_litellm_params.items():
|
||||
if k not in data: # prioritize model-specific params > default router params
|
||||
data[k] = v
|
||||
|
||||
########## remove -ModelID-XXXX from model ##############
|
||||
original_model_string = data["model"]
|
||||
# Find the index of "ModelID" in the string
|
||||
self.print_verbose(f"completion model: {original_model_string}")
|
||||
index_of_model_id = original_model_string.find("-ModelID")
|
||||
# Remove everything after "-ModelID" if it exists
|
||||
if index_of_model_id != -1:
|
||||
data["model"] = original_model_string[:index_of_model_id]
|
||||
else:
|
||||
data["model"] = original_model_string
|
||||
if k not in kwargs: # prioritize model-specific params > default router params
|
||||
kwargs[k] = v
|
||||
elif k == "metadata":
|
||||
kwargs[k].update(v)
|
||||
model_client = self._get_client(deployment=deployment, kwargs=kwargs)
|
||||
return litellm.completion(**{**data, "messages": messages, "caching": self.cache_responses, "client": model_client, **kwargs})
|
||||
except Exception as e:
|
||||
|
@ -241,31 +239,25 @@ class Router:
|
|||
**kwargs):
|
||||
try:
|
||||
self.print_verbose(f"Inside _acompletion()- model: {model}; kwargs: {kwargs}")
|
||||
original_model_string = None # set a default for this variable
|
||||
deployment = self.get_available_deployment(model=model, messages=messages, specific_deployment=kwargs.pop("specific_deployment", None))
|
||||
kwargs.setdefault("metadata", {}).update({"deployment": deployment["litellm_params"]["model"]})
|
||||
kwargs["model_info"] = deployment.get("model_info", {})
|
||||
data = deployment["litellm_params"].copy()
|
||||
model_name = data["model"]
|
||||
for k, v in self.default_litellm_params.items():
|
||||
if k not in data: # prioritize model-specific params > default router params
|
||||
data[k] = v
|
||||
########## remove -ModelID-XXXX from model ##############
|
||||
original_model_string = data["model"]
|
||||
# Find the index of "ModelID" in the string
|
||||
index_of_model_id = original_model_string.find("-ModelID")
|
||||
# Remove everything after "-ModelID" if it exists
|
||||
if index_of_model_id != -1:
|
||||
data["model"] = original_model_string[:index_of_model_id]
|
||||
else:
|
||||
data["model"] = original_model_string
|
||||
if k not in kwargs: # prioritize model-specific params > default router params
|
||||
kwargs[k] = v
|
||||
elif k == "metadata":
|
||||
kwargs[k].update(v)
|
||||
|
||||
model_client = self._get_client(deployment=deployment, kwargs=kwargs, client_type="async")
|
||||
self.total_calls[original_model_string] +=1
|
||||
self.total_calls[model_name] +=1
|
||||
response = await litellm.acompletion(**{**data, "messages": messages, "caching": self.cache_responses, "client": model_client, **kwargs})
|
||||
self.success_calls[original_model_string] +=1
|
||||
self.success_calls[model_name] +=1
|
||||
return response
|
||||
except Exception as e:
|
||||
if original_model_string is not None:
|
||||
self.fail_calls[original_model_string] +=1
|
||||
if model_name is not None:
|
||||
self.fail_calls[model_name] +=1
|
||||
raise e
|
||||
|
||||
def text_completion(self,
|
||||
|
@ -283,8 +275,43 @@ class Router:
|
|||
|
||||
data = deployment["litellm_params"].copy()
|
||||
for k, v in self.default_litellm_params.items():
|
||||
if k not in data: # prioritize model-specific params > default router params
|
||||
data[k] = v
|
||||
if k not in kwargs: # prioritize model-specific params > default router params
|
||||
kwargs[k] = v
|
||||
elif k == "metadata":
|
||||
kwargs[k].update(v)
|
||||
|
||||
# call via litellm.completion()
|
||||
return litellm.text_completion(**{**data, "prompt": prompt, "caching": self.cache_responses, **kwargs}) # type: ignore
|
||||
except Exception as e:
|
||||
if self.num_retries > 0:
|
||||
kwargs["model"] = model
|
||||
kwargs["messages"] = messages
|
||||
kwargs["original_exception"] = e
|
||||
kwargs["original_function"] = self.completion
|
||||
return self.function_with_retries(**kwargs)
|
||||
else:
|
||||
raise e
|
||||
|
||||
async def atext_completion(self,
|
||||
model: str,
|
||||
prompt: str,
|
||||
is_retry: Optional[bool] = False,
|
||||
is_fallback: Optional[bool] = False,
|
||||
is_async: Optional[bool] = False,
|
||||
**kwargs):
|
||||
try:
|
||||
kwargs.setdefault("metadata", {}).update({"model_group": model})
|
||||
messages=[{"role": "user", "content": prompt}]
|
||||
# pick the one that is available (lowest TPM/RPM)
|
||||
deployment = self.get_available_deployment(model=model, messages=messages, specific_deployment=kwargs.pop("specific_deployment", None))
|
||||
|
||||
data = deployment["litellm_params"].copy()
|
||||
for k, v in self.default_litellm_params.items():
|
||||
if k not in kwargs: # prioritize model-specific params > default router params
|
||||
kwargs[k] = v
|
||||
elif k == "metadata":
|
||||
kwargs[k].update(v)
|
||||
|
||||
########## remove -ModelID-XXXX from model ##############
|
||||
original_model_string = data["model"]
|
||||
# Find the index of "ModelID" in the string
|
||||
|
@ -294,8 +321,9 @@ class Router:
|
|||
data["model"] = original_model_string[:index_of_model_id]
|
||||
else:
|
||||
data["model"] = original_model_string
|
||||
# call via litellm.completion()
|
||||
return litellm.text_completion(**{**data, "prompt": prompt, "caching": self.cache_responses, **kwargs}) # type: ignore
|
||||
# call via litellm.atext_completion()
|
||||
response = await litellm.atext_completion(**{**data, "prompt": prompt, "caching": self.cache_responses, **kwargs}) # type: ignore
|
||||
return response
|
||||
except Exception as e:
|
||||
if self.num_retries > 0:
|
||||
kwargs["model"] = model
|
||||
|
@ -313,21 +341,14 @@ class Router:
|
|||
**kwargs) -> Union[List[float], None]:
|
||||
# pick the one that is available (lowest TPM/RPM)
|
||||
deployment = self.get_available_deployment(model=model, input=input, specific_deployment=kwargs.pop("specific_deployment", None))
|
||||
kwargs.setdefault("metadata", {}).update({"deployment": deployment["litellm_params"]["model"]})
|
||||
kwargs["model_info"] = deployment.get("model_info", {})
|
||||
kwargs.setdefault("model_info", {})
|
||||
kwargs.setdefault("metadata", {}).update({"model_group": model, "deployment": deployment["litellm_params"]["model"]}) # [TODO]: move to using async_function_with_fallbacks
|
||||
data = deployment["litellm_params"].copy()
|
||||
for k, v in self.default_litellm_params.items():
|
||||
if k not in data: # prioritize model-specific params > default router params
|
||||
data[k] = v
|
||||
########## remove -ModelID-XXXX from model ##############
|
||||
original_model_string = data["model"]
|
||||
# Find the index of "ModelID" in the string
|
||||
index_of_model_id = original_model_string.find("-ModelID")
|
||||
# Remove everything after "-ModelID" if it exists
|
||||
if index_of_model_id != -1:
|
||||
data["model"] = original_model_string[:index_of_model_id]
|
||||
else:
|
||||
data["model"] = original_model_string
|
||||
if k not in kwargs: # prioritize model-specific params > default router params
|
||||
kwargs[k] = v
|
||||
elif k == "metadata":
|
||||
kwargs[k].update(v)
|
||||
model_client = self._get_client(deployment=deployment, kwargs=kwargs)
|
||||
# call via litellm.embedding()
|
||||
return litellm.embedding(**{**data, "input": input, "caching": self.cache_responses, "client": model_client, **kwargs})
|
||||
|
@ -339,21 +360,15 @@ class Router:
|
|||
**kwargs) -> Union[List[float], None]:
|
||||
# pick the one that is available (lowest TPM/RPM)
|
||||
deployment = self.get_available_deployment(model=model, input=input, specific_deployment=kwargs.pop("specific_deployment", None))
|
||||
kwargs.setdefault("metadata", {}).update({"deployment": deployment["litellm_params"]["model"]})
|
||||
kwargs.setdefault("metadata", {}).update({"model_group": model, "deployment": deployment["litellm_params"]["model"]})
|
||||
data = deployment["litellm_params"].copy()
|
||||
kwargs["model_info"] = deployment.get("model_info", {})
|
||||
for k, v in self.default_litellm_params.items():
|
||||
if k not in data: # prioritize model-specific params > default router params
|
||||
data[k] = v
|
||||
########## remove -ModelID-XXXX from model ##############
|
||||
original_model_string = data["model"]
|
||||
# Find the index of "ModelID" in the string
|
||||
index_of_model_id = original_model_string.find("-ModelID")
|
||||
# Remove everything after "-ModelID" if it exists
|
||||
if index_of_model_id != -1:
|
||||
data["model"] = original_model_string[:index_of_model_id]
|
||||
else:
|
||||
data["model"] = original_model_string
|
||||
if k not in kwargs: # prioritize model-specific params > default router params
|
||||
kwargs[k] = v
|
||||
elif k == "metadata":
|
||||
kwargs[k].update(v)
|
||||
|
||||
model_client = self._get_client(deployment=deployment, kwargs=kwargs, client_type="async")
|
||||
|
||||
return await litellm.aembedding(**{**data, "input": input, "caching": self.cache_responses, "client": model_client, **kwargs})
|
||||
|
@ -371,7 +386,7 @@ class Router:
|
|||
self.print_verbose(f'Async Response: {response}')
|
||||
return response
|
||||
except Exception as e:
|
||||
self.print_verbose(f"An exception occurs: {e}")
|
||||
self.print_verbose(f"An exception occurs: {e}\n\n Traceback{traceback.format_exc()}")
|
||||
original_exception = e
|
||||
try:
|
||||
self.print_verbose(f"Trying to fallback b/w models")
|
||||
|
@ -637,9 +652,10 @@ class Router:
|
|||
model_name = kwargs.get('model', None) # i.e. gpt35turbo
|
||||
custom_llm_provider = kwargs.get("litellm_params", {}).get('custom_llm_provider', None) # i.e. azure
|
||||
metadata = kwargs.get("litellm_params", {}).get('metadata', None)
|
||||
deployment_id = kwargs.get("litellm_params", {}).get("model_info").get("id")
|
||||
self._set_cooldown_deployments(deployment_id) # setting deployment_id in cooldown deployments
|
||||
if metadata:
|
||||
deployment = metadata.get("deployment", None)
|
||||
self._set_cooldown_deployments(deployment)
|
||||
deployment_exceptions = self.model_exception_map.get(deployment, [])
|
||||
deployment_exceptions.append(exception_str)
|
||||
self.model_exception_map[deployment] = deployment_exceptions
|
||||
|
@ -877,7 +893,7 @@ class Router:
|
|||
return chosen_item
|
||||
|
||||
def set_model_list(self, model_list: list):
|
||||
self.model_list = model_list
|
||||
self.model_list = copy.deepcopy(model_list)
|
||||
# we add api_base/api_key each model so load balancing between azure/gpt on api_base1 and api_base2 works
|
||||
import os
|
||||
for model in self.model_list:
|
||||
|
@ -889,23 +905,26 @@ class Router:
|
|||
model["model_info"] = model_info
|
||||
#### for OpenAI / Azure we need to initalize the Client for High Traffic ########
|
||||
custom_llm_provider = litellm_params.get("custom_llm_provider")
|
||||
if custom_llm_provider is None:
|
||||
custom_llm_provider = model_name.split("/",1)[0]
|
||||
custom_llm_provider = custom_llm_provider or model_name.split("/",1)[0] or ""
|
||||
default_api_base = None
|
||||
default_api_key = None
|
||||
if custom_llm_provider in litellm.openai_compatible_providers:
|
||||
_, custom_llm_provider, api_key, api_base = litellm.get_llm_provider(model=model_name)
|
||||
default_api_base = api_base
|
||||
default_api_key = api_key
|
||||
if (
|
||||
model_name in litellm.open_ai_chat_completion_models
|
||||
or custom_llm_provider == "custom_openai"
|
||||
or custom_llm_provider == "deepinfra"
|
||||
or custom_llm_provider == "perplexity"
|
||||
or custom_llm_provider == "anyscale"
|
||||
or custom_llm_provider == "openai"
|
||||
or custom_llm_provider in litellm.openai_compatible_providers
|
||||
or custom_llm_provider == "azure"
|
||||
or custom_llm_provider == "custom_openai"
|
||||
or custom_llm_provider == "openai"
|
||||
or "ft:gpt-3.5-turbo" in model_name
|
||||
or model_name in litellm.open_ai_embedding_models
|
||||
):
|
||||
# glorified / complicated reading of configs
|
||||
# user can pass vars directly or they can pas os.environ/AZURE_API_KEY, in which case we will read the env
|
||||
# we do this here because we init clients for Azure, OpenAI and we need to set the right key
|
||||
api_key = litellm_params.get("api_key")
|
||||
api_key = litellm_params.get("api_key") or default_api_key
|
||||
if api_key and api_key.startswith("os.environ/"):
|
||||
api_key_env_name = api_key.replace("os.environ/", "")
|
||||
api_key = litellm.get_secret(api_key_env_name)
|
||||
|
@ -913,7 +932,7 @@ class Router:
|
|||
|
||||
api_base = litellm_params.get("api_base")
|
||||
base_url = litellm_params.get("base_url")
|
||||
api_base = api_base or base_url # allow users to pass in `api_base` or `base_url` for azure
|
||||
api_base = api_base or base_url or default_api_base # allow users to pass in `api_base` or `base_url` for azure
|
||||
if api_base and api_base.startswith("os.environ/"):
|
||||
api_base_env_name = api_base.replace("os.environ/", "")
|
||||
api_base = litellm.get_secret(api_base_env_name)
|
||||
|
@ -1049,12 +1068,6 @@ class Router:
|
|||
|
||||
############ End of initializing Clients for OpenAI/Azure ###################
|
||||
self.deployment_names.append(model["litellm_params"]["model"])
|
||||
model_id = ""
|
||||
for key in model["litellm_params"]:
|
||||
if key != "api_key" and key != "metadata":
|
||||
model_id+= str(model["litellm_params"][key])
|
||||
model["litellm_params"]["model"] += "-ModelID-" + model_id
|
||||
|
||||
self.print_verbose(f"\n Initialized Model List {self.model_list}")
|
||||
|
||||
############ Users can either pass tpm/rpm as a litellm_param or a router param ###########
|
||||
|
@ -1115,38 +1128,41 @@ class Router:
|
|||
if specific_deployment == True:
|
||||
# users can also specify a specific deployment name. At this point we should check if they are just trying to call a specific deployment
|
||||
for deployment in self.model_list:
|
||||
cleaned_model = litellm.utils.remove_model_id(deployment.get("litellm_params").get("model"))
|
||||
if cleaned_model == model:
|
||||
deployment_model = deployment.get("litellm_params").get("model")
|
||||
if deployment_model == model:
|
||||
# User Passed a specific deployment name on their config.yaml, example azure/chat-gpt-v-2
|
||||
# return the first deployment where the `model` matches the specificed deployment name
|
||||
return deployment
|
||||
raise ValueError(f"LiteLLM Router: Trying to call specific deployment, but Model:{model} does not exist in Model List: {self.model_list}")
|
||||
|
||||
# check if aliases set on litellm model alias map
|
||||
if model in litellm.model_group_alias_map:
|
||||
self.print_verbose(f"Using a model alias. Got Request for {model}, sending requests to {litellm.model_group_alias_map.get(model)}")
|
||||
model = litellm.model_group_alias_map[model]
|
||||
if model in self.model_group_alias:
|
||||
self.print_verbose(f"Using a model alias. Got Request for {model}, sending requests to {self.model_group_alias.get(model)}")
|
||||
model = self.model_group_alias[model]
|
||||
|
||||
## get healthy deployments
|
||||
### get all deployments
|
||||
### filter out the deployments currently cooling down
|
||||
healthy_deployments = [m for m in self.model_list if m["model_name"] == model]
|
||||
if len(healthy_deployments) == 0:
|
||||
# check if the user sent in a deployment name instead
|
||||
healthy_deployments = [m for m in self.model_list if m["litellm_params"]["model"] == model]
|
||||
|
||||
self.print_verbose(f"initial list of deployments: {healthy_deployments}")
|
||||
|
||||
# filter out the deployments currently cooling down
|
||||
deployments_to_remove = []
|
||||
cooldown_deployments = self._get_cooldown_deployments()
|
||||
# cooldown_deployments is a list of model_id's cooling down, cooldown_deployments = ["16700539-b3cd-42f4-b426-6a12a1bb706a", "16700539-b3cd-42f4-b426-7899"]
|
||||
cooldown_deployments = self._get_cooldown_deployments()
|
||||
self.print_verbose(f"cooldown deployments: {cooldown_deployments}")
|
||||
### FIND UNHEALTHY DEPLOYMENTS
|
||||
# Find deployments in model_list whose model_id is cooling down
|
||||
for deployment in healthy_deployments:
|
||||
deployment_name = deployment["litellm_params"]["model"]
|
||||
if deployment_name in cooldown_deployments:
|
||||
deployment_id = deployment["model_info"]["id"]
|
||||
if deployment_id in cooldown_deployments:
|
||||
deployments_to_remove.append(deployment)
|
||||
### FILTER OUT UNHEALTHY DEPLOYMENTS
|
||||
# remove unhealthy deployments from healthy deployments
|
||||
for deployment in deployments_to_remove:
|
||||
healthy_deployments.remove(deployment)
|
||||
|
||||
self.print_verbose(f"healthy deployments: length {len(healthy_deployments)} {healthy_deployments}")
|
||||
if len(healthy_deployments) == 0:
|
||||
raise ValueError("No models available")
|
||||
|
@ -1222,11 +1238,14 @@ class Router:
|
|||
raise ValueError("No models available.")
|
||||
|
||||
def flush_cache(self):
|
||||
litellm.cache = None
|
||||
self.cache.flush_cache()
|
||||
|
||||
def reset(self):
|
||||
## clean up on close
|
||||
litellm.success_callback = []
|
||||
litellm.__async_success_callback = []
|
||||
litellm.failure_callback = []
|
||||
litellm._async_failure_callback = []
|
||||
self.flush_cache()
|
||||
|
34
litellm/tests/conftest.py
Normal file
|
@ -0,0 +1,34 @@
|
|||
# conftest.py
|
||||
|
||||
import pytest, sys, os
|
||||
import importlib
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../..")
|
||||
) # Adds the parent directory to the system path
|
||||
import litellm
|
||||
|
||||
|
||||
@pytest.fixture(scope="function", autouse=True)
|
||||
def setup_and_teardown():
|
||||
"""
|
||||
This fixture reloads litellm before every function. To speed up testing by removing callbacks being chained.
|
||||
"""
|
||||
curr_dir = os.getcwd() # Get the current working directory
|
||||
sys.path.insert(0, os.path.abspath("../..")) # Adds the project directory to the system path
|
||||
import litellm
|
||||
importlib.reload(litellm)
|
||||
print(litellm)
|
||||
# from litellm import Router, completion, aembedding, acompletion, embedding
|
||||
yield
|
||||
|
||||
def pytest_collection_modifyitems(config, items):
|
||||
# Separate tests in 'test_amazing_proxy_custom_logger.py' and other tests
|
||||
custom_logger_tests = [item for item in items if 'custom_logger' in item.parent.name]
|
||||
other_tests = [item for item in items if 'custom_logger' not in item.parent.name]
|
||||
|
||||
# Sort tests based on their names
|
||||
custom_logger_tests.sort(key=lambda x: x.name)
|
||||
other_tests.sort(key=lambda x: x.name)
|
||||
|
||||
# Reorder the items list
|
||||
items[:] = custom_logger_tests + other_tests
|
7
litellm/tests/example_config_yaml/cache_no_params.yaml
Normal file
|
@ -0,0 +1,7 @@
|
|||
model_list:
|
||||
- model_name: "openai-model"
|
||||
litellm_params:
|
||||
model: "gpt-3.5-turbo"
|
||||
|
||||
litellm_settings:
|
||||
cache: True
|
10
litellm/tests/example_config_yaml/cache_with_params.yaml
Normal file
|
@ -0,0 +1,10 @@
|
|||
model_list:
|
||||
- model_name: "openai-model"
|
||||
litellm_params:
|
||||
model: "gpt-3.5-turbo"
|
||||
|
||||
litellm_settings:
|
||||
cache: True
|
||||
cache_params:
|
||||
supported_call_types: ["embedding", "aembedding"]
|
||||
host: "localhost"
|
4
litellm/tests/langfuse.log
Normal file
|
@ -0,0 +1,4 @@
|
|||
uploading batch of 2 items
|
||||
successfully uploaded batch of 2 items
|
||||
uploading batch of 2 items
|
||||
successfully uploaded batch of 2 items
|
|
@ -118,6 +118,7 @@ def test_cooldown_same_model_name():
|
|||
"api_key": os.getenv("AZURE_API_KEY"),
|
||||
"api_version": os.getenv("AZURE_API_VERSION"),
|
||||
"api_base": "BAD_API_BASE",
|
||||
"tpm": 90
|
||||
},
|
||||
},
|
||||
{
|
||||
|
@ -126,7 +127,8 @@ def test_cooldown_same_model_name():
|
|||
"model": "azure/chatgpt-v-2",
|
||||
"api_key": os.getenv("AZURE_API_KEY"),
|
||||
"api_version": os.getenv("AZURE_API_VERSION"),
|
||||
"api_base": os.getenv("AZURE_API_BASE")
|
||||
"api_base": os.getenv("AZURE_API_BASE"),
|
||||
"tpm": 0.000001
|
||||
},
|
||||
},
|
||||
]
|
||||
|
@ -151,13 +153,14 @@ def test_cooldown_same_model_name():
|
|||
]
|
||||
)
|
||||
print(router.model_list)
|
||||
litellm_model_names = []
|
||||
model_ids = []
|
||||
for model in router.model_list:
|
||||
litellm_model_names.append(model["litellm_params"]["model"])
|
||||
print("\n litellm model names ", litellm_model_names)
|
||||
model_ids.append(model["model_info"]["id"])
|
||||
print("\n litellm model ids ", model_ids)
|
||||
|
||||
# example litellm_model_names ['azure/chatgpt-v-2-ModelID-64321', 'azure/chatgpt-v-2-ModelID-63960']
|
||||
assert litellm_model_names[0] != litellm_model_names[1] # ensure both models have a uuid added, and they have different names
|
||||
assert model_ids[0] != model_ids[1] # ensure both models have a uuid added, and they have different names
|
||||
|
||||
print("\ngot response\n", response)
|
||||
except Exception as e:
|
||||
pytest.fail(f"Got unexpected exception on router! - {e}")
|
||||
|
|
|
@ -9,9 +9,9 @@ import os, io
|
|||
sys.path.insert(
|
||||
0, os.path.abspath("../..")
|
||||
) # Adds the parent directory to the system path
|
||||
import pytest
|
||||
import pytest, asyncio
|
||||
import litellm
|
||||
from litellm import embedding, completion, completion_cost, Timeout
|
||||
from litellm import embedding, completion, completion_cost, Timeout, acompletion
|
||||
from litellm import RateLimitError
|
||||
import json
|
||||
import os
|
||||
|
@ -63,6 +63,27 @@ def load_vertex_ai_credentials():
|
|||
# Export the temporary file as GOOGLE_APPLICATION_CREDENTIALS
|
||||
os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = os.path.abspath(temp_file.name)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def get_response():
|
||||
load_vertex_ai_credentials()
|
||||
prompt = '\ndef count_nums(arr):\n """\n Write a function count_nums which takes an array of integers and returns\n the number of elements which has a sum of digits > 0.\n If a number is negative, then its first signed digit will be negative:\n e.g. -123 has signed digits -1, 2, and 3.\n >>> count_nums([]) == 0\n >>> count_nums([-1, 11, -11]) == 1\n >>> count_nums([1, 1, 2]) == 3\n """\n'
|
||||
try:
|
||||
response = await acompletion(
|
||||
model="gemini-pro",
|
||||
messages=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": "Complete the given code with no more explanation. Remember that there is a 4-space indent before the first line of your generated code.",
|
||||
},
|
||||
{"role": "user", "content": prompt},
|
||||
],
|
||||
)
|
||||
return response
|
||||
except litellm.UnprocessableEntityError as e:
|
||||
pass
|
||||
except Exception as e:
|
||||
pytest.fail(f"An error occurred - {str(e)}")
|
||||
|
||||
|
||||
def test_vertex_ai():
|
||||
import random
|
||||
|
@ -72,14 +93,15 @@ def test_vertex_ai():
|
|||
litellm.set_verbose=False
|
||||
litellm.vertex_project = "hardy-device-386718"
|
||||
|
||||
test_models = random.sample(test_models, 4)
|
||||
test_models = random.sample(test_models, 1)
|
||||
test_models += litellm.vertex_language_models # always test gemini-pro
|
||||
for model in test_models:
|
||||
try:
|
||||
if model in ["code-gecko@001", "code-gecko@latest", "code-bison@001", "text-bison@001"]:
|
||||
if model in ["code-gecko", "code-gecko@001", "code-gecko@002", "code-gecko@latest", "code-bison@001", "text-bison@001"]:
|
||||
# our account does not have access to this model
|
||||
continue
|
||||
print("making request", model)
|
||||
response = completion(model=model, messages=[{'role': 'user', 'content': 'hi'}])
|
||||
response = completion(model=model, messages=[{'role': 'user', 'content': 'hi'}], temperature=0.7)
|
||||
print("\nModel Response", response)
|
||||
print(response)
|
||||
assert type(response.choices[0].message.content) == str
|
||||
|
@ -94,11 +116,12 @@ def test_vertex_ai_stream():
|
|||
litellm.vertex_project = "hardy-device-386718"
|
||||
import random
|
||||
|
||||
test_models = litellm.vertex_chat_models + litellm.vertex_code_chat_models + litellm.vertex_text_models + litellm.vertex_code_text_models
|
||||
test_models = random.sample(test_models, 4)
|
||||
test_models = litellm.vertex_chat_models + litellm.vertex_code_chat_models + litellm.vertex_text_models + litellm.vertex_code_text_models
|
||||
test_models = random.sample(test_models, 1)
|
||||
test_models += litellm.vertex_language_models # always test gemini-pro
|
||||
for model in test_models:
|
||||
try:
|
||||
if model in ["code-gecko@001", "code-gecko@latest", "code-bison@001", "text-bison@001"]:
|
||||
if model in ["code-gecko", "code-gecko@001", "code-gecko@002", "code-gecko@latest", "code-bison@001", "text-bison@001"]:
|
||||
# our account does not have access to this model
|
||||
continue
|
||||
print("making request", model)
|
||||
|
@ -115,3 +138,199 @@ def test_vertex_ai_stream():
|
|||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
# test_vertex_ai_stream()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_vertexai_response():
|
||||
import random
|
||||
load_vertex_ai_credentials()
|
||||
test_models = litellm.vertex_chat_models + litellm.vertex_code_chat_models + litellm.vertex_text_models + litellm.vertex_code_text_models
|
||||
test_models = random.sample(test_models, 1)
|
||||
test_models += litellm.vertex_language_models # always test gemini-pro
|
||||
for model in test_models:
|
||||
print(f'model being tested in async call: {model}')
|
||||
if model in ["code-gecko", "code-gecko@001", "code-gecko@002", "code-gecko@latest", "code-bison@001", "text-bison@001"]:
|
||||
# our account does not have access to this model
|
||||
continue
|
||||
try:
|
||||
user_message = "Hello, how are you?"
|
||||
messages = [{"content": user_message, "role": "user"}]
|
||||
response = await acompletion(model=model, messages=messages, temperature=0.7, timeout=5)
|
||||
print(f"response: {response}")
|
||||
except litellm.Timeout as e:
|
||||
pass
|
||||
except Exception as e:
|
||||
pytest.fail(f"An exception occurred: {e}")
|
||||
|
||||
# asyncio.run(test_async_vertexai_response())
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_vertexai_streaming_response():
|
||||
import random
|
||||
load_vertex_ai_credentials()
|
||||
test_models = litellm.vertex_chat_models + litellm.vertex_code_chat_models + litellm.vertex_text_models + litellm.vertex_code_text_models
|
||||
test_models = random.sample(test_models, 1)
|
||||
test_models += litellm.vertex_language_models # always test gemini-pro
|
||||
for model in test_models:
|
||||
if model in ["code-gecko", "code-gecko@001", "code-gecko@002", "code-gecko@latest", "code-bison@001", "text-bison@001"]:
|
||||
# our account does not have access to this model
|
||||
continue
|
||||
try:
|
||||
user_message = "Hello, how are you?"
|
||||
messages = [{"content": user_message, "role": "user"}]
|
||||
response = await acompletion(model="gemini-pro", messages=messages, temperature=0.7, timeout=5, stream=True)
|
||||
print(f"response: {response}")
|
||||
complete_response = ""
|
||||
async for chunk in response:
|
||||
print(f"chunk: {chunk}")
|
||||
complete_response += chunk.choices[0].delta.content
|
||||
print(f"complete_response: {complete_response}")
|
||||
assert len(complete_response) > 0
|
||||
except litellm.Timeout as e:
|
||||
pass
|
||||
except Exception as e:
|
||||
print(e)
|
||||
pytest.fail(f"An exception occurred: {e}")
|
||||
|
||||
# asyncio.run(test_async_vertexai_streaming_response())
|
||||
|
||||
def test_gemini_pro_vision():
|
||||
try:
|
||||
load_vertex_ai_credentials()
|
||||
litellm.set_verbose = True
|
||||
litellm.num_retries=0
|
||||
resp = litellm.completion(
|
||||
model = "vertex_ai/gemini-pro-vision",
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "Whats in this image?"
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": "gs://cloud-samples-data/generative-ai/image/boats.jpeg"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
)
|
||||
print(resp)
|
||||
except Exception as e:
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
raise e
|
||||
# test_gemini_pro_vision()
|
||||
|
||||
|
||||
# Extra gemini Vision tests for completion + stream, async, async + stream
|
||||
# if we run into issues with gemini, we will also add these to our ci/cd pipeline
|
||||
# def test_gemini_pro_vision_stream():
|
||||
# try:
|
||||
# litellm.set_verbose = False
|
||||
# litellm.num_retries=0
|
||||
# print("streaming response from gemini-pro-vision")
|
||||
# resp = litellm.completion(
|
||||
# model = "vertex_ai/gemini-pro-vision",
|
||||
# messages=[
|
||||
# {
|
||||
# "role": "user",
|
||||
# "content": [
|
||||
# {
|
||||
# "type": "text",
|
||||
# "text": "Whats in this image?"
|
||||
# },
|
||||
# {
|
||||
# "type": "image_url",
|
||||
# "image_url": {
|
||||
# "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
|
||||
# }
|
||||
# }
|
||||
# ]
|
||||
# }
|
||||
# ],
|
||||
# stream=True
|
||||
# )
|
||||
# print(resp)
|
||||
# for chunk in resp:
|
||||
# print(chunk)
|
||||
# except Exception as e:
|
||||
# import traceback
|
||||
# traceback.print_exc()
|
||||
# raise e
|
||||
# test_gemini_pro_vision_stream()
|
||||
|
||||
# def test_gemini_pro_vision_async():
|
||||
# try:
|
||||
# litellm.set_verbose = True
|
||||
# litellm.num_retries=0
|
||||
# async def test():
|
||||
# resp = await litellm.acompletion(
|
||||
# model = "vertex_ai/gemini-pro-vision",
|
||||
# messages=[
|
||||
# {
|
||||
# "role": "user",
|
||||
# "content": [
|
||||
# {
|
||||
# "type": "text",
|
||||
# "text": "Whats in this image?"
|
||||
# },
|
||||
# {
|
||||
# "type": "image_url",
|
||||
# "image_url": {
|
||||
# "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
|
||||
# }
|
||||
# }
|
||||
# ]
|
||||
# }
|
||||
# ],
|
||||
# )
|
||||
# print("async response gemini pro vision")
|
||||
# print(resp)
|
||||
# asyncio.run(test())
|
||||
# except Exception as e:
|
||||
# import traceback
|
||||
# traceback.print_exc()
|
||||
# raise e
|
||||
# test_gemini_pro_vision_async()
|
||||
|
||||
|
||||
# def test_gemini_pro_vision_async_stream():
|
||||
# try:
|
||||
# litellm.set_verbose = True
|
||||
# litellm.num_retries=0
|
||||
# async def test():
|
||||
# resp = await litellm.acompletion(
|
||||
# model = "vertex_ai/gemini-pro-vision",
|
||||
# messages=[
|
||||
# {
|
||||
# "role": "user",
|
||||
# "content": [
|
||||
# {
|
||||
# "type": "text",
|
||||
# "text": "Whats in this image?"
|
||||
# },
|
||||
# {
|
||||
# "type": "image_url",
|
||||
# "image_url": {
|
||||
# "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
|
||||
# }
|
||||
# }
|
||||
# ]
|
||||
# }
|
||||
# ],
|
||||
# stream=True
|
||||
# )
|
||||
# print("async response gemini pro vision")
|
||||
# print(resp)
|
||||
# for chunk in resp:
|
||||
# print(chunk)
|
||||
# asyncio.run(test())
|
||||
# except Exception as e:
|
||||
# import traceback
|
||||
# traceback.print_exc()
|
||||
# raise e
|
||||
# test_gemini_pro_vision_async()
|
|
@ -29,16 +29,19 @@ def generate_random_word(length=4):
|
|||
messages = [{"role": "user", "content": "who is ishaan 5222"}]
|
||||
def test_caching_v2(): # test in memory cache
|
||||
try:
|
||||
litellm.set_verbose=True
|
||||
litellm.cache = Cache()
|
||||
response1 = completion(model="gpt-3.5-turbo", messages=messages, caching=True)
|
||||
response2 = completion(model="gpt-3.5-turbo", messages=messages, caching=True)
|
||||
print(f"response1: {response1}")
|
||||
print(f"response2: {response2}")
|
||||
litellm.cache = None # disable cache
|
||||
litellm.success_callback = []
|
||||
litellm._async_success_callback = []
|
||||
if response2['choices'][0]['message']['content'] != response1['choices'][0]['message']['content']:
|
||||
print(f"response1: {response1}")
|
||||
print(f"response2: {response2}")
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
pytest.fail(f"Error occurred:")
|
||||
except Exception as e:
|
||||
print(f"error occurred: {traceback.format_exc()}")
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
|
@ -58,6 +61,8 @@ def test_caching_with_models_v2():
|
|||
print(f"response2: {response2}")
|
||||
print(f"response3: {response3}")
|
||||
litellm.cache = None
|
||||
litellm.success_callback = []
|
||||
litellm._async_success_callback = []
|
||||
if response3['choices'][0]['message']['content'] == response2['choices'][0]['message']['content']:
|
||||
# if models are different, it should not return cached response
|
||||
print(f"response2: {response2}")
|
||||
|
@ -91,6 +96,8 @@ def test_embedding_caching():
|
|||
print(f"Embedding 2 response time: {end_time - start_time} seconds")
|
||||
|
||||
litellm.cache = None
|
||||
litellm.success_callback = []
|
||||
litellm._async_success_callback = []
|
||||
assert end_time - start_time <= 0.1 # ensure 2nd response comes in in under 0.1 s
|
||||
if embedding2['data'][0]['embedding'] != embedding1['data'][0]['embedding']:
|
||||
print(f"embedding1: {embedding1}")
|
||||
|
@ -145,6 +152,8 @@ def test_embedding_caching_azure():
|
|||
print(f"Embedding 2 response time: {end_time - start_time} seconds")
|
||||
|
||||
litellm.cache = None
|
||||
litellm.success_callback = []
|
||||
litellm._async_success_callback = []
|
||||
assert end_time - start_time <= 0.1 # ensure 2nd response comes in in under 0.1 s
|
||||
if embedding2['data'][0]['embedding'] != embedding1['data'][0]['embedding']:
|
||||
print(f"embedding1: {embedding1}")
|
||||
|
@ -175,6 +184,8 @@ def test_redis_cache_completion():
|
|||
print("\nresponse 3", response3)
|
||||
print("\nresponse 4", response4)
|
||||
litellm.cache = None
|
||||
litellm.success_callback = []
|
||||
litellm._async_success_callback = []
|
||||
|
||||
"""
|
||||
1 & 2 should be exactly the same
|
||||
|
@ -226,6 +237,8 @@ def test_redis_cache_completion_stream():
|
|||
assert response_1_content == response_2_content, f"Response 1 != Response 2. Same params, Response 1{response_1_content} != Response 2{response_2_content}"
|
||||
litellm.success_callback = []
|
||||
litellm.cache = None
|
||||
litellm.success_callback = []
|
||||
litellm._async_success_callback = []
|
||||
except Exception as e:
|
||||
print(e)
|
||||
litellm.success_callback = []
|
||||
|
@ -271,11 +284,53 @@ def test_redis_cache_acompletion_stream():
|
|||
print("\nresponse 2", response_2_content)
|
||||
assert response_1_content == response_2_content, f"Response 1 != Response 2. Same params, Response 1{response_1_content} != Response 2{response_2_content}"
|
||||
litellm.cache = None
|
||||
litellm.success_callback = []
|
||||
litellm._async_success_callback = []
|
||||
except Exception as e:
|
||||
print(e)
|
||||
raise e
|
||||
# test_redis_cache_acompletion_stream()
|
||||
|
||||
def test_redis_cache_acompletion_stream_bedrock():
|
||||
import asyncio
|
||||
try:
|
||||
litellm.set_verbose = True
|
||||
random_word = generate_random_word()
|
||||
messages = [{"role": "user", "content": f"write a one sentence poem about: {random_word}"}]
|
||||
litellm.cache = Cache(type="redis", host=os.environ['REDIS_HOST'], port=os.environ['REDIS_PORT'], password=os.environ['REDIS_PASSWORD'])
|
||||
print("test for caching, streaming + completion")
|
||||
response_1_content = ""
|
||||
response_2_content = ""
|
||||
|
||||
async def call1():
|
||||
nonlocal response_1_content
|
||||
response1 = await litellm.acompletion(model="bedrock/anthropic.claude-v1", messages=messages, max_tokens=40, temperature=1, stream=True)
|
||||
async for chunk in response1:
|
||||
print(chunk)
|
||||
response_1_content += chunk.choices[0].delta.content or ""
|
||||
print(response_1_content)
|
||||
asyncio.run(call1())
|
||||
time.sleep(0.5)
|
||||
print("\n\n Response 1 content: ", response_1_content, "\n\n")
|
||||
|
||||
async def call2():
|
||||
nonlocal response_2_content
|
||||
response2 = await litellm.acompletion(model="bedrock/anthropic.claude-v1", messages=messages, max_tokens=40, temperature=1, stream=True)
|
||||
async for chunk in response2:
|
||||
print(chunk)
|
||||
response_2_content += chunk.choices[0].delta.content or ""
|
||||
print(response_2_content)
|
||||
asyncio.run(call2())
|
||||
print("\nresponse 1", response_1_content)
|
||||
print("\nresponse 2", response_2_content)
|
||||
assert response_1_content == response_2_content, f"Response 1 != Response 2. Same params, Response 1{response_1_content} != Response 2{response_2_content}"
|
||||
litellm.cache = None
|
||||
litellm.success_callback = []
|
||||
litellm._async_success_callback = []
|
||||
except Exception as e:
|
||||
print(e)
|
||||
raise e
|
||||
# test_redis_cache_acompletion_stream_bedrock()
|
||||
# redis cache with custom keys
|
||||
def custom_get_cache_key(*args, **kwargs):
|
||||
# return key to use for your cache:
|
||||
|
@ -312,9 +367,44 @@ def test_custom_redis_cache_with_key():
|
|||
if response3['choices'][0]['message']['content'] == response2['choices'][0]['message']['content']:
|
||||
pytest.fail(f"Error occurred:")
|
||||
litellm.cache = None
|
||||
litellm.success_callback = []
|
||||
litellm._async_success_callback = []
|
||||
|
||||
# test_custom_redis_cache_with_key()
|
||||
|
||||
def test_cache_override():
|
||||
# test if we can override the cache, when `caching=False` but litellm.cache = Cache() is set
|
||||
# in this case it should not return cached responses
|
||||
litellm.cache = Cache()
|
||||
print("Testing cache override")
|
||||
litellm.set_verbose=True
|
||||
|
||||
# test embedding
|
||||
response1 = embedding(
|
||||
model = "text-embedding-ada-002",
|
||||
input=[
|
||||
"hello who are you"
|
||||
],
|
||||
caching = False
|
||||
)
|
||||
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
response2 = embedding(
|
||||
model = "text-embedding-ada-002",
|
||||
input=[
|
||||
"hello who are you"
|
||||
],
|
||||
caching = False
|
||||
)
|
||||
|
||||
end_time = time.time()
|
||||
print(f"Embedding 2 response time: {end_time - start_time} seconds")
|
||||
|
||||
assert end_time - start_time > 0.1 # ensure 2nd response comes in over 0.1s. This should not be cached.
|
||||
# test_cache_override()
|
||||
|
||||
|
||||
def test_custom_redis_cache_params():
|
||||
# test if we can init redis with **kwargs
|
||||
|
@ -333,6 +423,8 @@ def test_custom_redis_cache_params():
|
|||
|
||||
print(litellm.cache.cache.redis_client)
|
||||
litellm.cache = None
|
||||
litellm.success_callback = []
|
||||
litellm._async_success_callback = []
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred:", e)
|
||||
|
||||
|
@ -340,15 +432,58 @@ def test_custom_redis_cache_params():
|
|||
def test_get_cache_key():
|
||||
from litellm.caching import Cache
|
||||
try:
|
||||
print("Testing get_cache_key")
|
||||
cache_instance = Cache()
|
||||
cache_key = cache_instance.get_cache_key(**{'model': 'gpt-3.5-turbo', 'messages': [{'role': 'user', 'content': 'write a one sentence poem about: 7510'}], 'max_tokens': 40, 'temperature': 0.2, 'stream': True, 'litellm_call_id': 'ffe75e7e-8a07-431f-9a74-71a5b9f35f0b', 'litellm_logging_obj': {}}
|
||||
)
|
||||
cache_key_2 = cache_instance.get_cache_key(**{'model': 'gpt-3.5-turbo', 'messages': [{'role': 'user', 'content': 'write a one sentence poem about: 7510'}], 'max_tokens': 40, 'temperature': 0.2, 'stream': True, 'litellm_call_id': 'ffe75e7e-8a07-431f-9a74-71a5b9f35f0b', 'litellm_logging_obj': {}}
|
||||
)
|
||||
assert cache_key == "model: gpt-3.5-turbomessages: [{'role': 'user', 'content': 'write a one sentence poem about: 7510'}]temperature: 0.2max_tokens: 40"
|
||||
assert cache_key == cache_key_2, f"{cache_key} != {cache_key_2}. The same kwargs should have the same cache key across runs"
|
||||
|
||||
embedding_cache_key = cache_instance.get_cache_key(
|
||||
**{'model': 'azure/azure-embedding-model', 'api_base': 'https://openai-gpt-4-test-v-1.openai.azure.com/',
|
||||
'api_key': '', 'api_version': '2023-07-01-preview',
|
||||
'timeout': None, 'max_retries': 0, 'input': ['hi who is ishaan'],
|
||||
'caching': True,
|
||||
'client': "<openai.lib.azure.AsyncAzureOpenAI object at 0x12b6a1060>"
|
||||
}
|
||||
)
|
||||
|
||||
print(embedding_cache_key)
|
||||
|
||||
assert embedding_cache_key == "model: azure/azure-embedding-modelinput: ['hi who is ishaan']", f"{embedding_cache_key} != 'model: azure/azure-embedding-modelinput: ['hi who is ishaan']'. The same kwargs should have the same cache key across runs"
|
||||
|
||||
# Proxy - embedding cache, test if embedding key, gets model_group and not model
|
||||
embedding_cache_key_2 = cache_instance.get_cache_key(
|
||||
**{'model': 'azure/azure-embedding-model', 'api_base': 'https://openai-gpt-4-test-v-1.openai.azure.com/',
|
||||
'api_key': '', 'api_version': '2023-07-01-preview',
|
||||
'timeout': None, 'max_retries': 0, 'input': ['hi who is ishaan'],
|
||||
'caching': True,
|
||||
'client': "<openai.lib.azure.AsyncAzureOpenAI object at 0x12b6a1060>",
|
||||
'proxy_server_request': {'url': 'http://0.0.0.0:8000/embeddings',
|
||||
'method': 'POST',
|
||||
'headers':
|
||||
{'host': '0.0.0.0:8000', 'user-agent': 'curl/7.88.1', 'accept': '*/*', 'content-type': 'application/json',
|
||||
'content-length': '80'},
|
||||
'body': {'model': 'azure-embedding-model', 'input': ['hi who is ishaan']}},
|
||||
'user': None,
|
||||
'metadata': {'user_api_key': None,
|
||||
'headers': {'host': '0.0.0.0:8000', 'user-agent': 'curl/7.88.1', 'accept': '*/*', 'content-type': 'application/json', 'content-length': '80'},
|
||||
'model_group': 'EMBEDDING_MODEL_GROUP',
|
||||
'deployment': 'azure/azure-embedding-model-ModelID-azure/azure-embedding-modelhttps://openai-gpt-4-test-v-1.openai.azure.com/2023-07-01-preview'},
|
||||
'model_info': {'mode': 'embedding', 'base_model': 'text-embedding-ada-002', 'id': '20b2b515-f151-4dd5-a74f-2231e2f54e29'},
|
||||
'litellm_call_id': '2642e009-b3cd-443d-b5dd-bb7d56123b0e', 'litellm_logging_obj': '<litellm.utils.Logging object at 0x12f1bddb0>'}
|
||||
)
|
||||
|
||||
print(embedding_cache_key_2)
|
||||
assert embedding_cache_key_2 == "model: EMBEDDING_MODEL_GROUPinput: ['hi who is ishaan']"
|
||||
print("passed!")
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
pytest.fail(f"Error occurred:", e)
|
||||
|
||||
# test_get_cache_key()
|
||||
test_get_cache_key()
|
||||
|
||||
# test_custom_redis_cache_params()
|
||||
|
||||
|
|
|
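For reference, the key-derivation behaviour pinned down by test_get_cache_key above can be sketched as a minimal check; only the fact that identical kwargs yield identical keys is taken from the test, everything else here is illustrative.

from litellm.caching import Cache

cache_instance = Cache()
kwargs = {
    "model": "gpt-3.5-turbo",
    "messages": [{"role": "user", "content": "write a one sentence poem about: 7510"}],
    "max_tokens": 40,
    "temperature": 0.2,
}

# the cache key is derived from the call kwargs, so the same kwargs
# must map to the same key across repeated calls
key_a = cache_instance.get_cache_key(**kwargs)
key_b = cache_instance.get_cache_key(**kwargs)
assert key_a == key_b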
@ -21,6 +21,13 @@ messages = [{"content": user_message, "role": "user"}]
|
|||
def logger_fn(user_model_dict):
|
||||
print(f"user_model_dict: {user_model_dict}")
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_callbacks():
|
||||
print("\npytest fixture - resetting callbacks")
|
||||
litellm.success_callback = []
|
||||
litellm._async_success_callback = []
|
||||
litellm.failure_callback = []
|
||||
litellm.callbacks = []
|
||||
|
||||
def test_completion_custom_provider_model_name():
|
||||
try:
|
||||
|
@ -54,13 +61,32 @@ def test_completion_claude():
|
|||
print(response)
|
||||
print(response.usage)
|
||||
print(response.usage.completion_tokens)
|
||||
print(response["usage"]["completion_tokens"])
|
||||
print(response["usage"]["completion_tokens"])
|
||||
# print("new cost tracking")
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
# test_completion_claude()
|
||||
|
||||
def test_completion_mistral_api():
|
||||
try:
|
||||
litellm.set_verbose=True
|
||||
response = completion(
|
||||
model="mistral/mistral-tiny",
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Hey, how's it going?",
|
||||
}
|
||||
],
|
||||
safe_mode = True
|
||||
)
|
||||
# Add any assertions here to check the response
|
||||
print(response)
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
# test_completion_mistral_api()
|
||||
|
||||
def test_completion_claude2_1():
|
||||
try:
|
||||
print("claude2.1 test request")
|
||||
|
@ -287,7 +313,7 @@ def hf_test_completion_tgi():
|
|||
print(response)
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
hf_test_completion_tgi()
|
||||
# hf_test_completion_tgi()
|
||||
|
||||
# ################### Hugging Face Conversational models ########################
|
||||
# def hf_test_completion_conv():
|
||||
|
@ -611,7 +637,7 @@ def test_completion_azure_key_completion_arg():
|
|||
os.environ.pop("AZURE_API_KEY", None)
|
||||
try:
|
||||
print("azure gpt-3.5 test\n\n")
|
||||
litellm.set_verbose=False
|
||||
litellm.set_verbose=True
|
||||
## Test azure call
|
||||
response = completion(
|
||||
model="azure/chatgpt-v-2",
|
||||
|
@ -696,6 +722,7 @@ def test_completion_azure():
|
|||
print(response)
|
||||
|
||||
cost = completion_cost(completion_response=response)
|
||||
assert cost > 0.0
|
||||
print("Cost for azure completion request", cost)
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
|
@ -1013,15 +1040,56 @@ def test_completion_together_ai():
|
|||
# Add any assertions here to check the response
|
||||
print(response)
|
||||
cost = completion_cost(completion_response=response)
|
||||
assert cost > 0.0
|
||||
print("Cost for completion call together-computer/llama-2-70b: ", f"${float(cost):.10f}")
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
def test_completion_together_ai_mixtral():
|
||||
model_name = "together_ai/DiscoResearch/DiscoLM-mixtral-8x7b-v2"
|
||||
try:
|
||||
messages =[
|
||||
{"role": "user", "content": "Who are you"},
|
||||
{"role": "assistant", "content": "I am your helpful assistant."},
|
||||
{"role": "user", "content": "Tell me a joke"},
|
||||
]
|
||||
response = completion(model=model_name, messages=messages, max_tokens=256, n=1, logger_fn=logger_fn)
|
||||
# Add any assertions here to check the response
|
||||
print(response)
|
||||
cost = completion_cost(completion_response=response)
|
||||
assert cost > 0.0
|
||||
print("Cost for completion call together-computer/llama-2-70b: ", f"${float(cost):.10f}")
|
||||
except litellm.Timeout as e:
|
||||
pass
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
test_completion_together_ai_mixtral()
|
||||
|
||||
def test_completion_together_ai_yi_chat():
|
||||
model_name = "together_ai/zero-one-ai/Yi-34B-Chat"
|
||||
try:
|
||||
messages =[
|
||||
{"role": "user", "content": "What llm are you?"},
|
||||
]
|
||||
response = completion(model=model_name, messages=messages)
|
||||
# Add any assertions here to check the response
|
||||
print(response)
|
||||
cost = completion_cost(completion_response=response)
|
||||
assert cost > 0.0
|
||||
print("Cost for completion call together-computer/llama-2-70b: ", f"${float(cost):.10f}")
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
# test_completion_together_ai_yi_chat()
|
||||
|
||||
# test_completion_together_ai()
|
||||
def test_customprompt_together_ai():
|
||||
try:
|
||||
litellm.set_verbose = False
|
||||
litellm.num_retries = 0
|
||||
print("in test_customprompt_together_ai")
|
||||
print(litellm.success_callback)
|
||||
print(litellm._async_success_callback)
|
||||
response = completion(
|
||||
model="together_ai/mistralai/Mistral-7B-Instruct-v0.1",
|
||||
messages=messages,
|
||||
|
@ -1030,7 +1098,6 @@ def test_customprompt_together_ai():
|
|||
print(response)
|
||||
except litellm.exceptions.Timeout as e:
|
||||
print(f"Timeout Error")
|
||||
litellm.num_retries = 3 # reset retries
|
||||
pass
|
||||
except Exception as e:
|
||||
print(f"ERROR TYPE {type(e)}")
|
||||
|
@ -1065,7 +1132,7 @@ def test_completion_chat_sagemaker():
|
|||
temperature=0.7,
|
||||
stream=True,
|
||||
)
|
||||
# Add any assertions here to check the response
|
||||
# Add any assertions here to check the response
|
||||
complete_response = ""
|
||||
for chunk in response:
|
||||
complete_response += chunk.choices[0].delta.content or ""
|
||||
|
|
|
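The streaming tests above all reassemble the final text the same way; a minimal sketch of that accumulation pattern is below (model and prompt are illustrative, OpenAI credentials assumed).

from litellm import completion

response = completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
    stream=True,
)

complete_response = ""
for chunk in response:
    # delta.content can be None on some chunks (e.g. the final one), hence `or ""`
    complete_response += chunk.choices[0].delta.content or ""
print(complete_response)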
@ -47,7 +47,7 @@ def test_config_context_moderation():
|
|||
print(f"Exception: {e}")
|
||||
pytest.fail(f"An exception occurred: {e}")
|
||||
|
||||
# test_config_context_moderation()
|
||||
# test_config_context_moderation()
|
||||
|
||||
def test_config_context_default_fallback():
|
||||
try:
|
||||
|
|
|
@ -2,7 +2,7 @@ from litellm.integrations.custom_logger import CustomLogger
|
|||
import inspect
|
||||
import litellm
|
||||
|
||||
class MyCustomHandler(CustomLogger):
|
||||
class testCustomCallbackProxy(CustomLogger):
|
||||
def __init__(self):
|
||||
self.success: bool = False # type: ignore
|
||||
self.failure: bool = False # type: ignore
|
||||
|
@ -55,8 +55,11 @@ class MyCustomHandler(CustomLogger):
|
|||
self.async_success = True
|
||||
print("Value of async success: ", self.async_success)
|
||||
print("\n kwargs: ", kwargs)
|
||||
if kwargs.get("model") == "azure-embedding-model":
|
||||
if kwargs.get("model") == "azure-embedding-model" or kwargs.get("model") == "ada":
|
||||
print("Got an embedding model", kwargs.get("model"))
|
||||
print("Setting embedding success to True")
|
||||
self.async_success_embedding = True
|
||||
print("Value of async success embedding: ", self.async_success_embedding)
|
||||
self.async_embedding_kwargs = kwargs
|
||||
self.async_embedding_response = response_obj
|
||||
if kwargs.get("stream") == True:
|
||||
|
@ -79,6 +82,9 @@ class MyCustomHandler(CustomLogger):
|
|||
# tokens used in response
|
||||
usage = response_obj["usage"]
|
||||
|
||||
print("\n\n in custom callback vars my custom logger, ", vars(my_custom_logger))
|
||||
|
||||
|
||||
print(
|
||||
f"""
|
||||
Model: {model},
|
||||
|
@ -104,4 +110,4 @@ class MyCustomHandler(CustomLogger):
|
|||
|
||||
self.async_completion_kwargs_fail = kwargs
|
||||
|
||||
my_custom_logger = MyCustomHandler()
|
||||
my_custom_logger = testCustomCallbackProxy()
|
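For reference, the callback interface this proxy test exercises can be sketched as a minimal handler; the hook names follow the CustomLogger methods used throughout these tests, while the handler name, model, and messages are illustrative.

import litellm
from litellm.integrations.custom_logger import CustomLogger

class MinimalHandler(CustomLogger):  # illustrative name
    def log_success_event(self, kwargs, response_obj, start_time, end_time):
        print("sync success:", kwargs.get("model"))

    def log_failure_event(self, kwargs, response_obj, start_time, end_time):
        print("sync failure:", kwargs.get("model"))

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        print("async success:", kwargs.get("model"))

litellm.callbacks = [MinimalHandler()]
litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hi 👋"}],
)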
16
litellm/tests/test_configs/test_bad_config.yaml
Normal file
|
@ -0,0 +1,16 @@
|
|||
model_list:
|
||||
- model_name: gpt-3.5-turbo
|
||||
litellm_params:
|
||||
api_key: bad-key
|
||||
model: gpt-3.5-turbo
|
||||
- model_name: azure-gpt-3.5-turbo
|
||||
litellm_params:
|
||||
model: azure/chatgpt-v-2
|
||||
api_base: os.environ/AZURE_API_BASE
|
||||
api_key: bad-key
|
||||
- model_name: azure-embedding
|
||||
litellm_params:
|
||||
model: azure/azure-embedding-model
|
||||
api_base: os.environ/AZURE_API_BASE
|
||||
api_key: bad-key
|
||||
|
|
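The os.environ/<VAR> values in these config files are placeholders substituted from the environment when the config is loaded; a rough sketch of that substitution is shown below. This is an illustrative helper only, not the proxy's actual loader, and the resolve_env_refs name is made up.

import os
import yaml  # assumes PyYAML is available

def resolve_env_refs(value):
    # replace "os.environ/<VAR>" strings with the corresponding env value
    if isinstance(value, str) and value.startswith("os.environ/"):
        return os.environ.get(value.split("/", 1)[1], "")
    if isinstance(value, dict):
        return {k: resolve_env_refs(v) for k, v in value.items()}
    if isinstance(value, list):
        return [resolve_env_refs(v) for v in value]
    return value

with open("litellm/tests/test_configs/test_bad_config.yaml") as f:
    config = resolve_env_refs(yaml.safe_load(f))

# api_base now holds the value of AZURE_API_BASE; api_key stays "bad-key"
print(config["model_list"][1]["litellm_params"]["api_base"])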
@ -19,3 +19,63 @@ model_list:
|
|||
model_info:
|
||||
description: this is a test openai model
|
||||
model_name: test_openai_models
|
||||
- litellm_params:
|
||||
model: gpt-3.5-turbo
|
||||
model_info:
|
||||
description: this is a test openai model
|
||||
id: 56f1bd94-3b54-4b67-9ea2-7c70e9a3a709
|
||||
model_name: test_openai_models
|
||||
- litellm_params:
|
||||
model: gpt-3.5-turbo
|
||||
model_info:
|
||||
description: this is a test openai model
|
||||
id: 4d1ee26c-abca-450c-8744-8e87fd6755e9
|
||||
model_name: test_openai_models
|
||||
- litellm_params:
|
||||
model: gpt-3.5-turbo
|
||||
model_info:
|
||||
description: this is a test openai model
|
||||
id: 00e19c0f-b63d-42bb-88e9-016fb0c60764
|
||||
model_name: test_openai_models
|
||||
- litellm_params:
|
||||
model: gpt-3.5-turbo
|
||||
model_info:
|
||||
description: this is a test openai model
|
||||
id: 79fc75bf-8e1b-47d5-8d24-9365a854af03
|
||||
model_name: test_openai_models
|
||||
- litellm_params:
|
||||
api_base: os.environ/AZURE_API_BASE
|
||||
api_key: os.environ/AZURE_API_KEY
|
||||
api_version: 2023-07-01-preview
|
||||
model: azure/azure-embedding-model
|
||||
model_name: azure-embedding-model
|
||||
- litellm_params:
|
||||
model: gpt-3.5-turbo
|
||||
model_info:
|
||||
description: this is a test openai model
|
||||
id: 55848c55-4162-40f9-a6e2-9a722b9ef404
|
||||
model_name: test_openai_models
|
||||
- litellm_params:
|
||||
model: gpt-3.5-turbo
|
||||
model_info:
|
||||
description: this is a test openai model
|
||||
id: 34339b1e-e030-4bcc-a531-c48559f10ce4
|
||||
model_name: test_openai_models
|
||||
- litellm_params:
|
||||
model: gpt-3.5-turbo
|
||||
model_info:
|
||||
description: this is a test openai model
|
||||
id: f6f74e14-ac64-4403-9365-319e584dcdc5
|
||||
model_name: test_openai_models
|
||||
- litellm_params:
|
||||
model: gpt-3.5-turbo
|
||||
model_info:
|
||||
description: this is a test openai model
|
||||
id: 9b1ef341-322c-410a-8992-903987fef439
|
||||
model_name: test_openai_models
|
||||
- model_name: amazon-embeddings
|
||||
litellm_params:
|
||||
model: "bedrock/amazon.titan-embed-text-v1"
|
||||
- model_name: "GPT-J 6B - Sagemaker Text Embedding (Internal)"
|
||||
litellm_params:
|
||||
model: "sagemaker/berri-benchmarking-gpt-j-6b-fp16"
|
631
litellm/tests/test_custom_callback_input.py
Normal file
|
@ -0,0 +1,631 @@
|
|||
### What this tests ####
|
||||
## This test asserts the type of data passed into each method of the custom callback handler
|
||||
import sys, os, time, inspect, asyncio, traceback
|
||||
from datetime import datetime
|
||||
import pytest
|
||||
sys.path.insert(0, os.path.abspath('../..'))
|
||||
from typing import Optional, Literal, List, Union
|
||||
from litellm import completion, embedding, Cache
|
||||
import litellm
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
|
||||
# Test Scenarios (test across completion, streaming, embedding)
|
||||
## 1: Pre-API-Call
|
||||
## 2: Post-API-Call
|
||||
## 3: On LiteLLM Call success
|
||||
## 4: On LiteLLM Call failure
|
||||
## 5. Caching
|
||||
|
||||
# Test models
|
||||
## 1. OpenAI
|
||||
## 2. Azure OpenAI
|
||||
## 3. Non-OpenAI/Azure - e.g. Bedrock
|
||||
|
||||
# Test interfaces
|
||||
## 1. litellm.completion() + litellm.embeddings()
|
||||
## refer to test_custom_callback_input_router.py for the router + proxy tests
|
||||
|
||||
class CompletionCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/observability/custom_callback#callback-class
|
||||
"""
|
||||
The set of expected inputs to a custom handler for a
|
||||
"""
|
||||
# Class variables or attributes
|
||||
def __init__(self):
|
||||
self.errors = []
|
||||
self.states: Optional[List[Literal["sync_pre_api_call", "async_pre_api_call", "post_api_call", "sync_stream", "async_stream", "sync_success", "async_success", "sync_failure", "async_failure"]]] = []
|
||||
|
||||
def log_pre_api_call(self, model, messages, kwargs):
|
||||
try:
|
||||
self.states.append("sync_pre_api_call")
|
||||
## MODEL
|
||||
assert isinstance(model, str)
|
||||
## MESSAGES
|
||||
assert isinstance(messages, list)
|
||||
## KWARGS
|
||||
assert isinstance(kwargs['model'], str)
|
||||
assert isinstance(kwargs['messages'], list)
|
||||
assert isinstance(kwargs['optional_params'], dict)
|
||||
assert isinstance(kwargs['litellm_params'], dict)
|
||||
assert isinstance(kwargs['start_time'], (datetime, type(None)))
|
||||
assert isinstance(kwargs['stream'], bool)
|
||||
assert isinstance(kwargs['user'], (str, type(None)))
|
||||
except Exception as e:
|
||||
print(f"Assertion Error: {traceback.format_exc()}")
|
||||
self.errors.append(traceback.format_exc())
|
||||
|
||||
def log_post_api_call(self, kwargs, response_obj, start_time, end_time):
|
||||
try:
|
||||
self.states.append("post_api_call")
|
||||
## START TIME
|
||||
assert isinstance(start_time, datetime)
|
||||
## END TIME
|
||||
assert end_time == None
|
||||
## RESPONSE OBJECT
|
||||
assert response_obj == None
|
||||
## KWARGS
|
||||
assert isinstance(kwargs['model'], str)
|
||||
assert isinstance(kwargs['messages'], list)
|
||||
assert isinstance(kwargs['optional_params'], dict)
|
||||
assert isinstance(kwargs['litellm_params'], dict)
|
||||
assert isinstance(kwargs['start_time'], (datetime, type(None)))
|
||||
assert isinstance(kwargs['stream'], bool)
|
||||
assert isinstance(kwargs['user'], (str, type(None)))
|
||||
assert isinstance(kwargs['input'], (list, dict, str))
|
||||
assert isinstance(kwargs['api_key'], (str, type(None)))
|
||||
assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or inspect.iscoroutine(kwargs['original_response']) or inspect.isasyncgen(kwargs['original_response'])
|
||||
assert isinstance(kwargs['additional_args'], (dict, type(None)))
|
||||
assert isinstance(kwargs['log_event_type'], str)
|
||||
except:
|
||||
print(f"Assertion Error: {traceback.format_exc()}")
|
||||
self.errors.append(traceback.format_exc())
|
||||
|
||||
async def async_log_stream_event(self, kwargs, response_obj, start_time, end_time):
|
||||
try:
|
||||
self.states.append("async_stream")
|
||||
## START TIME
|
||||
assert isinstance(start_time, datetime)
|
||||
## END TIME
|
||||
assert isinstance(end_time, datetime)
|
||||
## RESPONSE OBJECT
|
||||
assert isinstance(response_obj, litellm.ModelResponse)
|
||||
## KWARGS
|
||||
assert isinstance(kwargs['model'], str)
|
||||
assert isinstance(kwargs['messages'], list) and isinstance(kwargs['messages'][0], dict)
|
||||
assert isinstance(kwargs['optional_params'], dict)
|
||||
assert isinstance(kwargs['litellm_params'], dict)
|
||||
assert isinstance(kwargs['start_time'], (datetime, type(None)))
|
||||
assert isinstance(kwargs['stream'], bool)
|
||||
assert isinstance(kwargs['user'], (str, type(None)))
|
||||
assert (isinstance(kwargs['input'], list) and isinstance(kwargs['input'][0], dict)) or isinstance(kwargs['input'], (dict, str))
|
||||
assert isinstance(kwargs['api_key'], (str, type(None)))
|
||||
assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or inspect.isasyncgen(kwargs['original_response']) or inspect.iscoroutine(kwargs['original_response'])
|
||||
assert isinstance(kwargs['additional_args'], (dict, type(None)))
|
||||
assert isinstance(kwargs['log_event_type'], str)
|
||||
except:
|
||||
print(f"Assertion Error: {traceback.format_exc()}")
|
||||
self.errors.append(traceback.format_exc())
|
||||
|
||||
def log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||
try:
|
||||
self.states.append("sync_success")
|
||||
## START TIME
|
||||
assert isinstance(start_time, datetime)
|
||||
## END TIME
|
||||
assert isinstance(end_time, datetime)
|
||||
## RESPONSE OBJECT
|
||||
assert isinstance(response_obj, litellm.ModelResponse)
|
||||
## KWARGS
|
||||
assert isinstance(kwargs['model'], str)
|
||||
assert isinstance(kwargs['messages'], list) and isinstance(kwargs['messages'][0], dict)
|
||||
assert isinstance(kwargs['optional_params'], dict)
|
||||
assert isinstance(kwargs['litellm_params'], dict)
|
||||
assert isinstance(kwargs['start_time'], (datetime, type(None)))
|
||||
assert isinstance(kwargs['stream'], bool)
|
||||
assert isinstance(kwargs['user'], (str, type(None)))
|
||||
assert (isinstance(kwargs['input'], list) and isinstance(kwargs['input'][0], dict)) or isinstance(kwargs['input'], (dict, str))
|
||||
assert isinstance(kwargs['api_key'], (str, type(None)))
|
||||
assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper))
|
||||
assert isinstance(kwargs['additional_args'], (dict, type(None)))
|
||||
assert isinstance(kwargs['log_event_type'], str)
|
||||
except:
|
||||
print(f"Assertion Error: {traceback.format_exc()}")
|
||||
self.errors.append(traceback.format_exc())
|
||||
|
||||
def log_failure_event(self, kwargs, response_obj, start_time, end_time):
|
||||
try:
|
||||
self.states.append("sync_failure")
|
||||
## START TIME
|
||||
assert isinstance(start_time, datetime)
|
||||
## END TIME
|
||||
assert isinstance(end_time, datetime)
|
||||
## RESPONSE OBJECT
|
||||
assert response_obj == None
|
||||
## KWARGS
|
||||
assert isinstance(kwargs['model'], str)
|
||||
assert isinstance(kwargs['messages'], list) and isinstance(kwargs['messages'][0], dict)
|
||||
assert isinstance(kwargs['optional_params'], dict)
|
||||
assert isinstance(kwargs['litellm_params'], dict)
|
||||
assert isinstance(kwargs['start_time'], (datetime, type(None)))
|
||||
assert isinstance(kwargs['stream'], bool)
|
||||
assert isinstance(kwargs['user'], (str, type(None)))
|
||||
assert (isinstance(kwargs['input'], list) and isinstance(kwargs['input'][0], dict)) or isinstance(kwargs['input'], (dict, str))
|
||||
assert isinstance(kwargs['api_key'], (str, type(None)))
|
||||
assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or kwargs["original_response"] == None
|
||||
assert isinstance(kwargs['additional_args'], (dict, type(None)))
|
||||
assert isinstance(kwargs['log_event_type'], str)
|
||||
except:
|
||||
print(f"Assertion Error: {traceback.format_exc()}")
|
||||
self.errors.append(traceback.format_exc())
|
||||
|
||||
async def async_log_pre_api_call(self, model, messages, kwargs):
|
||||
try:
|
||||
self.states.append("async_pre_api_call")
|
||||
## MODEL
|
||||
assert isinstance(model, str)
|
||||
## MESSAGES
|
||||
assert isinstance(messages, list) and isinstance(messages[0], dict)
|
||||
## KWARGS
|
||||
assert isinstance(kwargs['model'], str)
|
||||
assert isinstance(kwargs['messages'], list) and isinstance(kwargs['messages'][0], dict)
|
||||
assert isinstance(kwargs['optional_params'], dict)
|
||||
assert isinstance(kwargs['litellm_params'], dict)
|
||||
assert isinstance(kwargs['start_time'], (datetime, type(None)))
|
||||
assert isinstance(kwargs['stream'], bool)
|
||||
assert isinstance(kwargs['user'], (str, type(None)))
|
||||
except Exception as e:
|
||||
print(f"Assertion Error: {traceback.format_exc()}")
|
||||
self.errors.append(traceback.format_exc())
|
||||
|
||||
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||
try:
|
||||
self.states.append("async_success")
|
||||
## START TIME
|
||||
assert isinstance(start_time, datetime)
|
||||
## END TIME
|
||||
assert isinstance(end_time, datetime)
|
||||
## RESPONSE OBJECT
|
||||
assert isinstance(response_obj, (litellm.ModelResponse, litellm.EmbeddingResponse))
|
||||
## KWARGS
|
||||
assert isinstance(kwargs['model'], str)
|
||||
assert isinstance(kwargs['messages'], list)
|
||||
assert isinstance(kwargs['optional_params'], dict)
|
||||
assert isinstance(kwargs['litellm_params'], dict)
|
||||
assert isinstance(kwargs['start_time'], (datetime, type(None)))
|
||||
assert isinstance(kwargs['stream'], bool)
|
||||
assert isinstance(kwargs['user'], (str, type(None)))
|
||||
assert isinstance(kwargs['input'], (list, dict, str))
|
||||
assert isinstance(kwargs['api_key'], (str, type(None)))
|
||||
assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or inspect.isasyncgen(kwargs['original_response']) or inspect.iscoroutine(kwargs['original_response'])
|
||||
assert isinstance(kwargs['additional_args'], (dict, type(None)))
|
||||
assert isinstance(kwargs['log_event_type'], str)
|
||||
assert kwargs["cache_hit"] is None or isinstance(kwargs["cache_hit"], bool)
|
||||
except:
|
||||
print(f"Assertion Error: {traceback.format_exc()}")
|
||||
self.errors.append(traceback.format_exc())
|
||||
|
||||
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
|
||||
try:
|
||||
self.states.append("async_failure")
|
||||
## START TIME
|
||||
assert isinstance(start_time, datetime)
|
||||
## END TIME
|
||||
assert isinstance(end_time, datetime)
|
||||
## RESPONSE OBJECT
|
||||
assert response_obj == None
|
||||
## KWARGS
|
||||
assert isinstance(kwargs['model'], str)
|
||||
assert isinstance(kwargs['messages'], list)
|
||||
assert isinstance(kwargs['optional_params'], dict)
|
||||
assert isinstance(kwargs['litellm_params'], dict)
|
||||
assert isinstance(kwargs['start_time'], (datetime, type(None)))
|
||||
assert isinstance(kwargs['stream'], bool)
|
||||
assert isinstance(kwargs['user'], (str, type(None)))
|
||||
assert isinstance(kwargs['input'], (list, str, dict))
|
||||
assert isinstance(kwargs['api_key'], (str, type(None)))
|
||||
assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or inspect.isasyncgen(kwargs['original_response']) or kwargs['original_response'] == None
|
||||
assert isinstance(kwargs['additional_args'], (dict, type(None)))
|
||||
assert isinstance(kwargs['log_event_type'], str)
|
||||
except:
|
||||
print(f"Assertion Error: {traceback.format_exc()}")
|
||||
self.errors.append(traceback.format_exc())
|
||||
|
||||
|
||||
# COMPLETION
|
||||
## Test OpenAI + sync
|
||||
def test_chat_openai_stream():
|
||||
try:
|
||||
customHandler = CompletionCustomHandler()
|
||||
litellm.callbacks = [customHandler]
|
||||
response = litellm.completion(model="gpt-3.5-turbo",
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": "Hi 👋 - i'm sync openai"
|
||||
}])
|
||||
## test streaming
|
||||
response = litellm.completion(model="gpt-3.5-turbo",
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": "Hi 👋 - i'm openai"
|
||||
}],
|
||||
stream=True)
|
||||
for chunk in response:
|
||||
continue
|
||||
## test failure callback
|
||||
try:
|
||||
response = litellm.completion(model="gpt-3.5-turbo",
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": "Hi 👋 - i'm openai"
|
||||
}],
|
||||
api_key="my-bad-key",
|
||||
stream=True)
|
||||
for chunk in response:
|
||||
continue
|
||||
except:
|
||||
pass
|
||||
time.sleep(1)
|
||||
print(f"customHandler.errors: {customHandler.errors}")
|
||||
assert len(customHandler.errors) == 0
|
||||
litellm.callbacks = []
|
||||
except Exception as e:
|
||||
pytest.fail(f"An exception occurred: {str(e)}")
|
||||
|
||||
# test_chat_openai_stream()
|
||||
|
||||
## Test OpenAI + Async
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_chat_openai_stream():
|
||||
try:
|
||||
customHandler = CompletionCustomHandler()
|
||||
litellm.callbacks = [customHandler]
|
||||
response = await litellm.acompletion(model="gpt-3.5-turbo",
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": "Hi 👋 - i'm openai"
|
||||
}])
|
||||
## test streaming
|
||||
response = await litellm.acompletion(model="gpt-3.5-turbo",
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": "Hi 👋 - i'm openai"
|
||||
}],
|
||||
stream=True)
|
||||
async for chunk in response:
|
||||
continue
|
||||
## test failure callback
|
||||
try:
|
||||
response = await litellm.acompletion(model="gpt-3.5-turbo",
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": "Hi 👋 - i'm openai"
|
||||
}],
|
||||
api_key="my-bad-key",
|
||||
stream=True)
|
||||
async for chunk in response:
|
||||
continue
|
||||
except:
|
||||
pass
|
||||
time.sleep(1)
|
||||
print(f"customHandler.errors: {customHandler.errors}")
|
||||
assert len(customHandler.errors) == 0
|
||||
litellm.callbacks = []
|
||||
except Exception as e:
|
||||
pytest.fail(f"An exception occurred: {str(e)}")
|
||||
|
||||
# asyncio.run(test_async_chat_openai_stream())
|
||||
|
||||
## Test Azure + sync
|
||||
def test_chat_azure_stream():
|
||||
try:
|
||||
customHandler = CompletionCustomHandler()
|
||||
litellm.callbacks = [customHandler]
|
||||
response = litellm.completion(model="azure/chatgpt-v-2",
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": "Hi 👋 - i'm sync azure"
|
||||
}])
|
||||
# test streaming
|
||||
response = litellm.completion(model="azure/chatgpt-v-2",
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": "Hi 👋 - i'm sync azure"
|
||||
}],
|
||||
stream=True)
|
||||
for chunk in response:
|
||||
continue
|
||||
# test failure callback
|
||||
try:
|
||||
response = litellm.completion(model="azure/chatgpt-v-2",
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": "Hi 👋 - i'm sync azure"
|
||||
}],
|
||||
api_key="my-bad-key",
|
||||
stream=True)
|
||||
for chunk in response:
|
||||
continue
|
||||
except:
|
||||
pass
|
||||
time.sleep(1)
|
||||
print(f"customHandler.errors: {customHandler.errors}")
|
||||
assert len(customHandler.errors) == 0
|
||||
litellm.callbacks = []
|
||||
except Exception as e:
|
||||
pytest.fail(f"An exception occurred: {str(e)}")
|
||||
|
||||
# test_chat_azure_stream()
|
||||
|
||||
## Test Azure + Async
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_chat_azure_stream():
|
||||
try:
|
||||
customHandler = CompletionCustomHandler()
|
||||
litellm.callbacks = [customHandler]
|
||||
response = await litellm.acompletion(model="azure/chatgpt-v-2",
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": "Hi 👋 - i'm async azure"
|
||||
}])
|
||||
## test streaming
|
||||
response = await litellm.acompletion(model="azure/chatgpt-v-2",
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": "Hi 👋 - i'm async azure"
|
||||
}],
|
||||
stream=True)
|
||||
async for chunk in response:
|
||||
continue
|
||||
## test failure callback
|
||||
try:
|
||||
response = await litellm.acompletion(model="azure/chatgpt-v-2",
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": "Hi 👋 - i'm async azure"
|
||||
}],
|
||||
api_key="my-bad-key",
|
||||
stream=True)
|
||||
async for chunk in response:
|
||||
continue
|
||||
except:
|
||||
pass
|
||||
await asyncio.sleep(1)
|
||||
print(f"customHandler.errors: {customHandler.errors}")
|
||||
assert len(customHandler.errors) == 0
|
||||
litellm.callbacks = []
|
||||
except Exception as e:
|
||||
pytest.fail(f"An exception occurred: {str(e)}")
|
||||
|
||||
# asyncio.run(test_async_chat_azure_stream())
|
||||
|
||||
## Test Bedrock + sync
|
||||
def test_chat_bedrock_stream():
|
||||
try:
|
||||
customHandler = CompletionCustomHandler()
|
||||
litellm.callbacks = [customHandler]
|
||||
response = litellm.completion(model="bedrock/anthropic.claude-v1",
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": "Hi 👋 - i'm sync bedrock"
|
||||
}])
|
||||
# test streaming
|
||||
response = litellm.completion(model="bedrock/anthropic.claude-v1",
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": "Hi 👋 - i'm sync bedrock"
|
||||
}],
|
||||
stream=True)
|
||||
for chunk in response:
|
||||
continue
|
||||
# test failure callback
|
||||
try:
|
||||
response = litellm.completion(model="bedrock/anthropic.claude-v1",
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": "Hi 👋 - i'm sync bedrock"
|
||||
}],
|
||||
aws_region_name="my-bad-region",
|
||||
stream=True)
|
||||
for chunk in response:
|
||||
continue
|
||||
except:
|
||||
pass
|
||||
time.sleep(1)
|
||||
print(f"customHandler.errors: {customHandler.errors}")
|
||||
assert len(customHandler.errors) == 0
|
||||
litellm.callbacks = []
|
||||
except Exception as e:
|
||||
pytest.fail(f"An exception occurred: {str(e)}")
|
||||
|
||||
# test_chat_bedrock_stream()
|
||||
|
||||
## Test Bedrock + Async
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_chat_bedrock_stream():
|
||||
try:
|
||||
customHandler = CompletionCustomHandler()
|
||||
litellm.callbacks = [customHandler]
|
||||
response = await litellm.acompletion(model="bedrock/anthropic.claude-v1",
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": "Hi 👋 - i'm async bedrock"
|
||||
}])
|
||||
# test streaming
|
||||
response = await litellm.acompletion(model="bedrock/anthropic.claude-v1",
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": "Hi 👋 - i'm async bedrock"
|
||||
}],
|
||||
stream=True)
|
||||
print(f"response: {response}")
|
||||
async for chunk in response:
|
||||
print(f"chunk: {chunk}")
|
||||
continue
|
||||
## test failure callback
|
||||
try:
|
||||
response = await litellm.acompletion(model="bedrock/anthropic.claude-v1",
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": "Hi 👋 - i'm async bedrock"
|
||||
}],
|
||||
aws_region_name="my-bad-key",
|
||||
stream=True)
|
||||
async for chunk in response:
|
||||
continue
|
||||
except:
|
||||
pass
|
||||
time.sleep(1)
|
||||
print(f"customHandler.errors: {customHandler.errors}")
|
||||
assert len(customHandler.errors) == 0
|
||||
litellm.callbacks = []
|
||||
except Exception as e:
|
||||
pytest.fail(f"An exception occurred: {str(e)}")
|
||||
|
||||
# asyncio.run(test_async_chat_bedrock_stream())
|
||||
|
||||
# EMBEDDING
|
||||
## Test OpenAI + Async
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_embedding_openai():
|
||||
try:
|
||||
customHandler_success = CompletionCustomHandler()
|
||||
customHandler_failure = CompletionCustomHandler()
|
||||
litellm.callbacks = [customHandler_success]
|
||||
response = await litellm.aembedding(model="azure/azure-embedding-model",
|
||||
input=["good morning from litellm"])
|
||||
await asyncio.sleep(1)
|
||||
print(f"customHandler_success.errors: {customHandler_success.errors}")
|
||||
print(f"customHandler_success.states: {customHandler_success.states}")
|
||||
assert len(customHandler_success.errors) == 0
|
||||
assert len(customHandler_success.states) == 3 # pre, post, success
|
||||
# test failure callback
|
||||
litellm.callbacks = [customHandler_failure]
|
||||
try:
|
||||
response = await litellm.aembedding(model="text-embedding-ada-002",
|
||||
input=["good morning from litellm"],
|
||||
api_key="my-bad-key")
|
||||
except:
|
||||
pass
|
||||
await asyncio.sleep(1)
|
||||
print(f"customHandler_failure.errors: {customHandler_failure.errors}")
|
||||
print(f"customHandler_failure.states: {customHandler_failure.states}")
|
||||
assert len(customHandler_failure.errors) == 0
|
||||
assert len(customHandler_failure.states) == 3 # pre, post, failure
|
||||
except Exception as e:
|
||||
pytest.fail(f"An exception occurred: {str(e)}")
|
||||
|
||||
# asyncio.run(test_async_embedding_openai())
|
||||
|
||||
## Test Azure + Async
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_embedding_azure():
|
||||
try:
|
||||
customHandler_success = CompletionCustomHandler()
|
||||
customHandler_failure = CompletionCustomHandler()
|
||||
litellm.callbacks = [customHandler_success]
|
||||
response = await litellm.aembedding(model="azure/azure-embedding-model",
|
||||
input=["good morning from litellm"])
|
||||
await asyncio.sleep(1)
|
||||
print(f"customHandler_success.errors: {customHandler_success.errors}")
|
||||
print(f"customHandler_success.states: {customHandler_success.states}")
|
||||
assert len(customHandler_success.errors) == 0
|
||||
assert len(customHandler_success.states) == 3 # pre, post, success
|
||||
# test failure callback
|
||||
litellm.callbacks = [customHandler_failure]
|
||||
try:
|
||||
response = await litellm.aembedding(model="azure/azure-embedding-model",
|
||||
input=["good morning from litellm"],
|
||||
api_key="my-bad-key")
|
||||
except:
|
||||
pass
|
||||
await asyncio.sleep(1)
|
||||
print(f"customHandler_failure.errors: {customHandler_failure.errors}")
|
||||
print(f"customHandler_failure.states: {customHandler_failure.states}")
|
||||
assert len(customHandler_failure.errors) == 0
|
||||
assert len(customHandler_failure.states) == 3 # pre, post, failure
|
||||
except Exception as e:
|
||||
pytest.fail(f"An exception occurred: {str(e)}")
|
||||
|
||||
# asyncio.run(test_async_embedding_azure())
|
||||
|
||||
## Test Bedrock + Async
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_embedding_bedrock():
|
||||
try:
|
||||
customHandler_success = CompletionCustomHandler()
|
||||
customHandler_failure = CompletionCustomHandler()
|
||||
litellm.callbacks = [customHandler_success]
|
||||
litellm.set_verbose = True
|
||||
response = await litellm.aembedding(model="bedrock/cohere.embed-multilingual-v3",
|
||||
input=["good morning from litellm"], aws_region_name="os.environ/AWS_REGION_NAME_2")
|
||||
await asyncio.sleep(1)
|
||||
print(f"customHandler_success.errors: {customHandler_success.errors}")
|
||||
print(f"customHandler_success.states: {customHandler_success.states}")
|
||||
assert len(customHandler_success.errors) == 0
|
||||
assert len(customHandler_success.states) == 3 # pre, post, success
|
||||
# test failure callback
|
||||
litellm.callbacks = [customHandler_failure]
|
||||
try:
|
||||
response = await litellm.aembedding(model="bedrock/cohere.embed-multilingual-v3",
|
||||
input=["good morning from litellm"],
|
||||
aws_region_name="my-bad-region")
|
||||
except:
|
||||
pass
|
||||
await asyncio.sleep(1)
|
||||
print(f"customHandler_failure.errors: {customHandler_failure.errors}")
|
||||
print(f"customHandler_failure.states: {customHandler_failure.states}")
|
||||
assert len(customHandler_failure.errors) == 0
|
||||
assert len(customHandler_failure.states) == 3 # pre, post, failure
|
||||
except Exception as e:
|
||||
pytest.fail(f"An exception occurred: {str(e)}")
|
||||
|
||||
# asyncio.run(test_async_embedding_bedrock())
|
||||
|
||||
# CACHING
|
||||
## Test Azure - completion, embedding
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_completion_azure_caching():
|
||||
customHandler_caching = CompletionCustomHandler()
|
||||
litellm.cache = Cache(type="redis", host=os.environ['REDIS_HOST'], port=os.environ['REDIS_PORT'], password=os.environ['REDIS_PASSWORD'])
|
||||
litellm.callbacks = [customHandler_caching]
|
||||
unique_time = time.time()
|
||||
response1 = await litellm.acompletion(model="azure/chatgpt-v-2",
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": f"Hi 👋 - i'm async azure {unique_time}"
|
||||
}],
|
||||
caching=True)
|
||||
await asyncio.sleep(1)
|
||||
print(f"customHandler_caching.states pre-cache hit: {customHandler_caching.states}")
|
||||
response2 = await litellm.acompletion(model="azure/chatgpt-v-2",
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": f"Hi 👋 - i'm async azure {unique_time}"
|
||||
}],
|
||||
caching=True)
|
||||
await asyncio.sleep(1) # success callbacks are done in parallel
|
||||
print(f"customHandler_caching.states post-cache hit: {customHandler_caching.states}")
|
||||
assert len(customHandler_caching.errors) == 0
|
||||
assert len(customHandler_caching.states) == 4 # pre, post, success, success
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_embedding_azure_caching():
|
||||
print("Testing custom callback input - Azure Caching")
|
||||
customHandler_caching = CompletionCustomHandler()
|
||||
litellm.cache = Cache(type="redis", host=os.environ['REDIS_HOST'], port=os.environ['REDIS_PORT'], password=os.environ['REDIS_PASSWORD'])
|
||||
litellm.callbacks = [customHandler_caching]
|
||||
unique_time = time.time()
|
||||
response1 = await litellm.aembedding(model="azure/azure-embedding-model",
|
||||
input=[f"good morning from litellm1 {unique_time}"],
|
||||
caching=True)
|
||||
await asyncio.sleep(1) # set cache is async for aembedding()
|
||||
response2 = await litellm.aembedding(model="azure/azure-embedding-model",
|
||||
input=[f"good morning from litellm1 {unique_time}"],
|
||||
caching=True)
|
||||
await asyncio.sleep(1) # success callbacks are done in parallel
|
||||
print(customHandler_caching.states)
|
||||
assert len(customHandler_caching.errors) == 0
|
||||
assert len(customHandler_caching.states) == 4 # pre, post, success, success
|
||||
|
||||
# asyncio.run(
|
||||
# test_async_embedding_azure_caching()
|
||||
# )
|
488
litellm/tests/test_custom_callback_router.py
Normal file
|
@ -0,0 +1,488 @@
|
|||
### What this tests ####
|
||||
## This test asserts the type of data passed into each method of the custom callback handler
|
||||
import sys, os, time, inspect, asyncio, traceback
|
||||
from datetime import datetime
|
||||
import pytest
|
||||
sys.path.insert(0, os.path.abspath('../..'))
|
||||
from typing import Optional, Literal, List
|
||||
from litellm import Router, Cache
|
||||
import litellm
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
|
||||
# Test Scenarios (test across completion, streaming, embedding)
|
||||
## 1: Pre-API-Call
|
||||
## 2: Post-API-Call
|
||||
## 3: On LiteLLM Call success
|
||||
## 4: On LiteLLM Call failure
|
||||
## fallbacks
|
||||
## retries
|
||||
|
||||
# Test cases
|
||||
## 1. Simple Azure OpenAI acompletion + streaming call
|
||||
## 2. Simple Azure OpenAI aembedding call
|
||||
## 3. Azure OpenAI acompletion + streaming call with retries
|
||||
## 4. Azure OpenAI aembedding call with retries
|
||||
## 5. Azure OpenAI acompletion + streaming call with fallbacks
|
||||
## 6. Azure OpenAI aembedding call with fallbacks
|
||||
|
||||
# Test interfaces
|
||||
## 1. router.completion() + router.embeddings()
|
||||
## 2. proxy.completions + proxy.embeddings
|
||||
|
||||
class CompletionCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/observability/custom_callback#callback-class
|
||||
"""
|
||||
The set of expected inputs to a custom handler for a
|
||||
"""
|
||||
# Class variables or attributes
|
||||
def __init__(self):
|
||||
self.errors = []
|
||||
self.states: Optional[List[Literal["sync_pre_api_call", "async_pre_api_call", "post_api_call", "sync_stream", "async_stream", "sync_success", "async_success", "sync_failure", "async_failure"]]] = []
|
||||
|
||||
def log_pre_api_call(self, model, messages, kwargs):
|
||||
try:
|
||||
print(f'received kwargs in pre-input: {kwargs}')
|
||||
self.states.append("sync_pre_api_call")
|
||||
## MODEL
|
||||
assert isinstance(model, str)
|
||||
## MESSAGES
|
||||
assert isinstance(messages, list)
|
||||
## KWARGS
|
||||
assert isinstance(kwargs['model'], str)
|
||||
assert isinstance(kwargs['messages'], list)
|
||||
assert isinstance(kwargs['optional_params'], dict)
|
||||
assert isinstance(kwargs['litellm_params'], dict)
|
||||
assert isinstance(kwargs['start_time'], (datetime, type(None)))
|
||||
assert isinstance(kwargs['stream'], bool)
|
||||
assert isinstance(kwargs['user'], (str, type(None)))
|
||||
### ROUTER-SPECIFIC KWARGS
|
||||
assert isinstance(kwargs["litellm_params"]["metadata"], dict)
|
||||
assert isinstance(kwargs["litellm_params"]["metadata"]["model_group"], str)
|
||||
assert isinstance(kwargs["litellm_params"]["metadata"]["deployment"], str)
|
||||
assert isinstance(kwargs["litellm_params"]["model_info"], dict)
|
||||
assert isinstance(kwargs["litellm_params"]["model_info"]["id"], str)
|
||||
assert isinstance(kwargs["litellm_params"]["proxy_server_request"], (str, type(None)))
|
||||
assert isinstance(kwargs["litellm_params"]["preset_cache_key"], (str, type(None)))
|
||||
assert isinstance(kwargs["litellm_params"]["stream_response"], dict)
|
||||
except Exception as e:
|
||||
print(f"Assertion Error: {traceback.format_exc()}")
|
||||
self.errors.append(traceback.format_exc())
|
||||
|
||||
def log_post_api_call(self, kwargs, response_obj, start_time, end_time):
|
||||
try:
|
||||
self.states.append("post_api_call")
|
||||
## START TIME
|
||||
assert isinstance(start_time, datetime)
|
||||
## END TIME
|
||||
assert end_time == None
|
||||
## RESPONSE OBJECT
|
||||
assert response_obj == None
|
||||
## KWARGS
|
||||
assert isinstance(kwargs['model'], str)
|
||||
assert isinstance(kwargs['messages'], list)
|
||||
assert isinstance(kwargs['optional_params'], dict)
|
||||
assert isinstance(kwargs['litellm_params'], dict)
|
||||
assert isinstance(kwargs['start_time'], (datetime, type(None)))
|
||||
assert isinstance(kwargs['stream'], bool)
|
||||
assert isinstance(kwargs['user'], (str, type(None)))
|
||||
assert isinstance(kwargs['input'], (list, dict, str))
|
||||
assert isinstance(kwargs['api_key'], (str, type(None)))
|
||||
assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or inspect.iscoroutine(kwargs['original_response']) or inspect.isasyncgen(kwargs['original_response'])
|
||||
assert isinstance(kwargs['additional_args'], (dict, type(None)))
|
||||
assert isinstance(kwargs['log_event_type'], str)
|
||||
### ROUTER-SPECIFIC KWARGS
|
||||
assert isinstance(kwargs["litellm_params"]["metadata"], dict)
|
||||
assert isinstance(kwargs["litellm_params"]["metadata"]["model_group"], str)
|
||||
assert isinstance(kwargs["litellm_params"]["metadata"]["deployment"], str)
|
||||
assert isinstance(kwargs["litellm_params"]["model_info"], dict)
|
||||
assert isinstance(kwargs["litellm_params"]["model_info"]["id"], str)
|
||||
assert isinstance(kwargs["litellm_params"]["proxy_server_request"], (str, type(None)))
|
||||
assert isinstance(kwargs["litellm_params"]["preset_cache_key"], (str, type(None)))
|
||||
assert isinstance(kwargs["litellm_params"]["stream_response"], dict)
|
||||
except:
|
||||
print(f"Assertion Error: {traceback.format_exc()}")
|
||||
self.errors.append(traceback.format_exc())
|
||||
|
||||
async def async_log_stream_event(self, kwargs, response_obj, start_time, end_time):
|
||||
try:
|
||||
self.states.append("async_stream")
|
||||
## START TIME
|
||||
assert isinstance(start_time, datetime)
|
||||
## END TIME
|
||||
assert isinstance(end_time, datetime)
|
||||
## RESPONSE OBJECT
|
||||
assert isinstance(response_obj, litellm.ModelResponse)
|
||||
## KWARGS
|
||||
assert isinstance(kwargs['model'], str)
|
||||
assert isinstance(kwargs['messages'], list) and isinstance(kwargs['messages'][0], dict)
|
||||
assert isinstance(kwargs['optional_params'], dict)
|
||||
assert isinstance(kwargs['litellm_params'], dict)
|
||||
assert isinstance(kwargs['start_time'], (datetime, type(None)))
|
||||
assert isinstance(kwargs['stream'], bool)
|
||||
assert isinstance(kwargs['user'], (str, type(None)))
|
||||
assert (isinstance(kwargs['input'], list) and isinstance(kwargs['input'][0], dict)) or isinstance(kwargs['input'], (dict, str))
|
||||
assert isinstance(kwargs['api_key'], (str, type(None)))
|
||||
assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or inspect.isasyncgen(kwargs['original_response']) or inspect.iscoroutine(kwargs['original_response'])
|
||||
assert isinstance(kwargs['additional_args'], (dict, type(None)))
|
||||
assert isinstance(kwargs['log_event_type'], str)
|
||||
except:
|
||||
print(f"Assertion Error: {traceback.format_exc()}")
|
||||
self.errors.append(traceback.format_exc())
|
||||
|
||||
def log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||
try:
|
||||
self.states.append("sync_success")
|
||||
## START TIME
|
||||
assert isinstance(start_time, datetime)
|
||||
## END TIME
|
||||
assert isinstance(end_time, datetime)
|
||||
## RESPONSE OBJECT
|
||||
assert isinstance(response_obj, litellm.ModelResponse)
|
||||
## KWARGS
|
||||
assert isinstance(kwargs['model'], str)
|
||||
assert isinstance(kwargs['messages'], list) and isinstance(kwargs['messages'][0], dict)
|
||||
assert isinstance(kwargs['optional_params'], dict)
|
||||
assert isinstance(kwargs['litellm_params'], dict)
|
||||
assert isinstance(kwargs['start_time'], (datetime, type(None)))
|
||||
assert isinstance(kwargs['stream'], bool)
|
||||
assert isinstance(kwargs['user'], (str, type(None)))
|
||||
assert (isinstance(kwargs['input'], list) and isinstance(kwargs['input'][0], dict)) or isinstance(kwargs['input'], (dict, str))
|
||||
assert isinstance(kwargs['api_key'], (str, type(None)))
|
||||
assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper))
|
||||
assert isinstance(kwargs['additional_args'], (dict, type(None)))
|
||||
assert isinstance(kwargs['log_event_type'], str)
|
||||
assert kwargs["cache_hit"] is None or isinstance(kwargs["cache_hit"], bool)
|
||||
except:
|
||||
print(f"Assertion Error: {traceback.format_exc()}")
|
||||
self.errors.append(traceback.format_exc())
|
||||
|
||||
def log_failure_event(self, kwargs, response_obj, start_time, end_time):
|
||||
try:
|
||||
self.states.append("sync_failure")
|
||||
## START TIME
|
||||
assert isinstance(start_time, datetime)
|
||||
## END TIME
|
||||
assert isinstance(end_time, datetime)
|
||||
## RESPONSE OBJECT
|
||||
assert response_obj == None
|
||||
## KWARGS
|
||||
assert isinstance(kwargs['model'], str)
|
||||
assert isinstance(kwargs['messages'], list) and isinstance(kwargs['messages'][0], dict)
|
||||
assert isinstance(kwargs['optional_params'], dict)
|
||||
assert isinstance(kwargs['litellm_params'], dict)
|
||||
assert isinstance(kwargs['start_time'], (datetime, type(None)))
|
||||
assert isinstance(kwargs['stream'], bool)
|
||||
assert isinstance(kwargs['user'], (str, type(None)))
|
||||
assert (isinstance(kwargs['input'], list) and isinstance(kwargs['input'][0], dict)) or isinstance(kwargs['input'], (dict, str))
|
||||
assert isinstance(kwargs['api_key'], (str, type(None)))
|
||||
assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or kwargs["original_response"] == None
|
||||
assert isinstance(kwargs['additional_args'], (dict, type(None)))
|
||||
assert isinstance(kwargs['log_event_type'], str)
|
||||
except:
|
||||
print(f"Assertion Error: {traceback.format_exc()}")
|
||||
self.errors.append(traceback.format_exc())
|
||||
|
||||
async def async_log_pre_api_call(self, model, messages, kwargs):
|
||||
try:
|
||||
"""
|
||||
No-op.
|
||||
Not implemented yet.
|
||||
"""
|
||||
pass
|
||||
except Exception as e:
|
||||
print(f"Assertion Error: {traceback.format_exc()}")
|
||||
self.errors.append(traceback.format_exc())
|
||||
|
||||
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||
try:
|
||||
self.states.append("async_success")
|
||||
## START TIME
|
||||
assert isinstance(start_time, datetime)
|
||||
## END TIME
|
||||
assert isinstance(end_time, datetime)
|
||||
## RESPONSE OBJECT
|
||||
assert isinstance(response_obj, (litellm.ModelResponse, litellm.EmbeddingResponse))
|
||||
## KWARGS
|
||||
assert isinstance(kwargs['model'], str)
|
||||
assert isinstance(kwargs['messages'], list)
|
||||
assert isinstance(kwargs['optional_params'], dict)
|
||||
assert isinstance(kwargs['litellm_params'], dict)
|
||||
assert isinstance(kwargs['start_time'], (datetime, type(None)))
|
||||
assert isinstance(kwargs['stream'], bool)
|
||||
assert isinstance(kwargs['user'], (str, type(None)))
|
||||
assert isinstance(kwargs['input'], (list, dict, str))
|
||||
assert isinstance(kwargs['api_key'], (str, type(None)))
|
||||
assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or inspect.isasyncgen(kwargs['original_response']) or inspect.iscoroutine(kwargs['original_response'])
|
||||
assert isinstance(kwargs['additional_args'], (dict, type(None)))
|
||||
assert isinstance(kwargs['log_event_type'], str)
|
||||
assert kwargs["cache_hit"] is None or isinstance(kwargs["cache_hit"], bool)
|
||||
### ROUTER-SPECIFIC KWARGS
|
||||
assert isinstance(kwargs["litellm_params"]["metadata"], dict)
|
||||
assert isinstance(kwargs["litellm_params"]["metadata"]["model_group"], str)
|
||||
assert isinstance(kwargs["litellm_params"]["metadata"]["deployment"], str)
|
||||
assert isinstance(kwargs["litellm_params"]["model_info"], dict)
|
||||
assert isinstance(kwargs["litellm_params"]["model_info"]["id"], str)
|
||||
assert isinstance(kwargs["litellm_params"]["proxy_server_request"], (str, type(None)))
|
||||
assert isinstance(kwargs["litellm_params"]["preset_cache_key"], (str, type(None)))
|
||||
assert isinstance(kwargs["litellm_params"]["stream_response"], dict)
|
||||
except:
|
||||
print(f"Assertion Error: {traceback.format_exc()}")
|
||||
self.errors.append(traceback.format_exc())
|
||||
|
||||
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
|
||||
try:
|
||||
print(f"received original response: {kwargs['original_response']}")
|
||||
self.states.append("async_failure")
|
||||
## START TIME
|
||||
assert isinstance(start_time, datetime)
|
||||
## END TIME
|
||||
assert isinstance(end_time, datetime)
|
||||
## RESPONSE OBJECT
|
||||
assert response_obj == None
|
||||
## KWARGS
|
||||
assert isinstance(kwargs['model'], str)
|
||||
assert isinstance(kwargs['messages'], list)
|
||||
assert isinstance(kwargs['optional_params'], dict)
|
||||
assert isinstance(kwargs['litellm_params'], dict)
|
||||
assert isinstance(kwargs['start_time'], (datetime, type(None)))
|
||||
assert isinstance(kwargs['stream'], bool)
|
||||
assert isinstance(kwargs['user'], (str, type(None)))
|
||||
assert isinstance(kwargs['input'], (list, str, dict))
|
||||
assert isinstance(kwargs['api_key'], (str, type(None)))
|
||||
assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or inspect.isasyncgen(kwargs['original_response']) or inspect.iscoroutine(kwargs['original_response']) or kwargs['original_response'] == None
|
||||
assert isinstance(kwargs['additional_args'], (dict, type(None)))
|
||||
assert isinstance(kwargs['log_event_type'], str)
|
||||
except:
|
||||
print(f"Assertion Error: {traceback.format_exc()}")
|
||||
self.errors.append(traceback.format_exc())
|
||||
|
||||
# Simple Azure OpenAI call
|
||||
## COMPLETION
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_chat_azure():
|
||||
try:
|
||||
customHandler_completion_azure_router = CompletionCustomHandler()
|
||||
customHandler_streaming_azure_router = CompletionCustomHandler()
|
||||
customHandler_failure = CompletionCustomHandler()
|
||||
litellm.callbacks = [customHandler_completion_azure_router]
|
||||
model_list = [
|
||||
{
|
||||
"model_name": "gpt-3.5-turbo", # openai model name
|
||||
"litellm_params": { # params for litellm completion/embedding call
|
||||
"model": "azure/chatgpt-v-2",
|
||||
"api_key": os.getenv("AZURE_API_KEY"),
|
||||
"api_version": os.getenv("AZURE_API_VERSION"),
|
||||
"api_base": os.getenv("AZURE_API_BASE")
|
||||
},
|
||||
"tpm": 240000,
|
||||
"rpm": 1800
|
||||
},
|
||||
]
|
||||
router = Router(model_list=model_list) # type: ignore
|
||||
response = await router.acompletion(model="gpt-3.5-turbo",
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": "Hi 👋 - i'm openai"
|
||||
}])
|
||||
await asyncio.sleep(2)
|
||||
assert len(customHandler_completion_azure_router.errors) == 0
|
||||
assert len(customHandler_completion_azure_router.states) == 3 # pre, post, success
|
||||
# streaming
|
||||
litellm.callbacks = [customHandler_streaming_azure_router]
|
||||
router2 = Router(model_list=model_list) # type: ignore
|
||||
response = await router2.acompletion(model="gpt-3.5-turbo",
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": "Hi 👋 - i'm openai"
|
||||
}],
|
||||
stream=True)
|
||||
async for chunk in response:
|
||||
print(f"async azure router chunk: {chunk}")
|
||||
continue
|
||||
await asyncio.sleep(1)
|
||||
print(f"customHandler.states: {customHandler_streaming_azure_router.states}")
|
||||
assert len(customHandler_streaming_azure_router.errors) == 0
|
||||
assert len(customHandler_streaming_azure_router.states) >= 4 # pre, post, stream (multiple times), success
|
||||
# failure
|
||||
model_list = [
|
||||
{
|
||||
"model_name": "gpt-3.5-turbo", # openai model name
|
||||
"litellm_params": { # params for litellm completion/embedding call
|
||||
"model": "azure/chatgpt-v-2",
|
||||
"api_key": "my-bad-key",
|
||||
"api_version": os.getenv("AZURE_API_VERSION"),
|
||||
"api_base": os.getenv("AZURE_API_BASE")
|
||||
},
|
||||
"tpm": 240000,
|
||||
"rpm": 1800
|
||||
},
|
||||
]
|
||||
litellm.callbacks = [customHandler_failure]
|
||||
router3 = Router(model_list=model_list) # type: ignore
|
||||
try:
|
||||
response = await router3.acompletion(model="gpt-3.5-turbo",
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": "Hi 👋 - i'm openai"
|
||||
}])
|
||||
print(f"response in router3 acompletion: {response}")
|
||||
except:
|
||||
pass
|
||||
await asyncio.sleep(1)
|
||||
print(f"customHandler.states: {customHandler_failure.states}")
|
||||
assert len(customHandler_failure.errors) == 0
|
||||
assert len(customHandler_failure.states) == 3 # pre, post, failure
|
||||
assert "async_failure" in customHandler_failure.states
|
||||
except Exception as e:
|
||||
print(f"Assertion Error: {traceback.format_exc()}")
|
||||
pytest.fail(f"An exception occurred - {str(e)}")
|
||||
# asyncio.run(test_async_chat_azure())
|
||||
## EMBEDDING
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_embedding_azure():
|
||||
try:
|
||||
customHandler = CompletionCustomHandler()
|
||||
customHandler_failure = CompletionCustomHandler()
|
||||
litellm.callbacks = [customHandler]
|
||||
model_list = [
|
||||
{
|
||||
"model_name": "azure-embedding-model", # openai model name
|
||||
"litellm_params": { # params for litellm completion/embedding call
|
||||
"model": "azure/azure-embedding-model",
|
||||
"api_key": os.getenv("AZURE_API_KEY"),
|
||||
"api_version": os.getenv("AZURE_API_VERSION"),
|
||||
"api_base": os.getenv("AZURE_API_BASE")
|
||||
},
|
||||
"tpm": 240000,
|
||||
"rpm": 1800
|
||||
},
|
||||
]
|
||||
router = Router(model_list=model_list) # type: ignore
|
||||
response = await router.aembedding(model="azure-embedding-model",
|
||||
input=["hello from litellm!"])
|
||||
await asyncio.sleep(2)
|
||||
assert len(customHandler.errors) == 0
|
||||
assert len(customHandler.states) == 3 # pre, post, success
|
||||
# failure
|
||||
model_list = [
|
||||
{
|
||||
"model_name": "azure-embedding-model", # openai model name
|
||||
"litellm_params": { # params for litellm completion/embedding call
|
||||
"model": "azure/azure-embedding-model",
|
||||
"api_key": "my-bad-key",
|
||||
"api_version": os.getenv("AZURE_API_VERSION"),
|
||||
"api_base": os.getenv("AZURE_API_BASE")
|
||||
},
|
||||
"tpm": 240000,
|
||||
"rpm": 1800
|
||||
},
|
||||
]
|
||||
litellm.callbacks = [customHandler_failure]
|
||||
router3 = Router(model_list=model_list) # type: ignore
|
||||
try:
|
||||
response = await router3.aembedding(model="azure-embedding-model",
|
||||
input=["hello from litellm!"])
|
||||
print(f"response in router3 aembedding: {response}")
|
||||
except:
|
||||
pass
|
||||
await asyncio.sleep(1)
|
||||
print(f"customHandler.states: {customHandler_failure.states}")
|
||||
assert len(customHandler_failure.errors) == 0
|
||||
assert len(customHandler_failure.states) == 3 # pre, post, failure
|
||||
assert "async_failure" in customHandler_failure.states
|
||||
except Exception as e:
|
||||
print(f"Assertion Error: {traceback.format_exc()}")
|
||||
pytest.fail(f"An exception occurred - {str(e)}")
|
||||
# asyncio.run(test_async_embedding_azure())
|
||||
# Azure OpenAI call w/ Fallbacks
|
||||
## COMPLETION
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_chat_azure_with_fallbacks():
|
||||
try:
|
||||
customHandler_fallbacks = CompletionCustomHandler()
|
||||
litellm.callbacks = [customHandler_fallbacks]
|
||||
# with fallbacks
|
||||
model_list = [
|
||||
{
|
||||
"model_name": "gpt-3.5-turbo", # openai model name
|
||||
"litellm_params": { # params for litellm completion/embedding call
|
||||
"model": "azure/chatgpt-v-2",
|
||||
"api_key": "my-bad-key",
|
||||
"api_version": os.getenv("AZURE_API_VERSION"),
|
||||
"api_base": os.getenv("AZURE_API_BASE")
|
||||
},
|
||||
"tpm": 240000,
|
||||
"rpm": 1800
|
||||
},
|
||||
{
|
||||
"model_name": "gpt-3.5-turbo-16k",
|
||||
"litellm_params": {
|
||||
"model": "gpt-3.5-turbo-16k",
|
||||
},
|
||||
"tpm": 240000,
|
||||
"rpm": 1800
|
||||
}
|
||||
]
|
||||
router = Router(model_list=model_list, fallbacks=[{"gpt-3.5-turbo": ["gpt-3.5-turbo-16k"]}]) # type: ignore
|
||||
response = await router.acompletion(model="gpt-3.5-turbo",
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": "Hi 👋 - i'm openai"
|
||||
}])
|
||||
await asyncio.sleep(2)
|
||||
print(f"customHandler_fallbacks.states: {customHandler_fallbacks.states}")
|
||||
assert len(customHandler_fallbacks.errors) == 0
|
||||
assert len(customHandler_fallbacks.states) == 6 # pre, post, failure, pre, post, success
|
||||
litellm.callbacks = []
|
||||
except Exception as e:
|
||||
print(f"Assertion Error: {traceback.format_exc()}")
|
||||
pytest.fail(f"An exception occurred - {str(e)}")
|
||||
# asyncio.run(test_async_chat_azure_with_fallbacks())
|
||||
|
||||
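# The fallback behaviour exercised above comes entirely from the `fallbacks` mapping passed
# to Router: each key is a model group, each value the ordered list of groups to retry when
# that group fails. A minimal sketch using only parameters that appear in the test above
# (the bad api_key is what forces the first deployment to fail):
import os, asyncio
from litellm import Router

fallback_model_list = [
    {
        "model_name": "gpt-3.5-turbo",
        "litellm_params": {
            "model": "azure/chatgpt-v-2",
            "api_key": "my-bad-key",  # deliberately broken primary deployment
            "api_version": os.getenv("AZURE_API_VERSION"),
            "api_base": os.getenv("AZURE_API_BASE"),
        },
    },
    {
        "model_name": "gpt-3.5-turbo-16k",
        "litellm_params": {"model": "gpt-3.5-turbo-16k"},
    },
]
fallback_router = Router(
    model_list=fallback_model_list,
    fallbacks=[{"gpt-3.5-turbo": ["gpt-3.5-turbo-16k"]}],
)

async def _fallback_demo():
    # failure callbacks fire for the broken deployment, success callbacks for the fallback
    return await fallback_router.acompletion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Hi 👋 - i'm openai"}],
    )
# asyncio.run(_fallback_demo())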
# CACHING
|
||||
## Test Azure - completion, embedding
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_completion_azure_caching():
|
||||
customHandler_caching = CompletionCustomHandler()
|
||||
litellm.cache = Cache(type="redis", host=os.environ['REDIS_HOST'], port=os.environ['REDIS_PORT'], password=os.environ['REDIS_PASSWORD'])
|
||||
litellm.callbacks = [customHandler_caching]
|
||||
unique_time = time.time()
|
||||
model_list = [
|
||||
{
|
||||
"model_name": "gpt-3.5-turbo", # openai model name
|
||||
"litellm_params": { # params for litellm completion/embedding call
|
||||
"model": "azure/chatgpt-v-2",
|
||||
"api_key": os.getenv("AZURE_API_KEY"),
|
||||
"api_version": os.getenv("AZURE_API_VERSION"),
|
||||
"api_base": os.getenv("AZURE_API_BASE")
|
||||
},
|
||||
"tpm": 240000,
|
||||
"rpm": 1800
|
||||
},
|
||||
{
|
||||
"model_name": "gpt-3.5-turbo-16k",
|
||||
"litellm_params": {
|
||||
"model": "gpt-3.5-turbo-16k",
|
||||
},
|
||||
"tpm": 240000,
|
||||
"rpm": 1800
|
||||
}
|
||||
]
|
||||
router = Router(model_list=model_list) # type: ignore
|
||||
response1 = await router.acompletion(model="gpt-3.5-turbo",
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": f"Hi 👋 - i'm async azure {unique_time}"
|
||||
}],
|
||||
caching=True)
|
||||
await asyncio.sleep(1)
|
||||
print(f"customHandler_caching.states pre-cache hit: {customHandler_caching.states}")
|
||||
response2 = await router.acompletion(model="gpt-3.5-turbo",
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": f"Hi 👋 - i'm async azure {unique_time}"
|
||||
}],
|
||||
caching=True)
|
||||
await asyncio.sleep(1) # success callbacks are done in parallel
|
||||
print(f"customHandler_caching.states post-cache hit: {customHandler_caching.states}")
|
||||
assert len(customHandler_caching.errors) == 0
|
||||
assert len(customHandler_caching.states) == 4 # pre, post, success, success
|
|
@ -1,15 +1,14 @@
|
|||
### What this tests ####
|
||||
import sys, os, time, inspect, asyncio
|
||||
import sys, os, time, inspect, asyncio, traceback
|
||||
import pytest
|
||||
sys.path.insert(0, os.path.abspath('../..'))
|
||||
|
||||
from litellm import completion, embedding
|
||||
import litellm
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
|
||||
async_success = False
|
||||
complete_streaming_response_in_callback = ""
|
||||
|
||||
class MyCustomHandler(CustomLogger):
|
||||
complete_streaming_response_in_callback = ""
|
||||
def __init__(self):
|
||||
self.success: bool = False # type: ignore
|
||||
self.failure: bool = False # type: ignore
|
||||
|
@ -27,9 +26,12 @@ class MyCustomHandler(CustomLogger):
|
|||
|
||||
self.stream_collected_response = None # type: ignore
|
||||
self.sync_stream_collected_response = None # type: ignore
|
||||
self.user = None # type: ignore
|
||||
self.data_sent_to_api: dict = {}
|
||||
|
||||
def log_pre_api_call(self, model, messages, kwargs):
|
||||
print(f"Pre-API Call")
|
||||
self.data_sent_to_api = kwargs["additional_args"].get("complete_input_dict", {})
|
||||
|
||||
def log_post_api_call(self, kwargs, response_obj, start_time, end_time):
|
||||
print(f"Post-API Call")
|
||||
|
@ -50,9 +52,8 @@ class MyCustomHandler(CustomLogger):
|
|||
|
||||
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||
print(f"On Async success")
|
||||
print(f"received kwargs user: {kwargs['user']}")
|
||||
self.async_success = True
|
||||
print("Value of async success: ", self.async_success)
|
||||
print("\n kwargs: ", kwargs)
|
||||
if kwargs.get("model") == "text-embedding-ada-002":
|
||||
self.async_success_embedding = True
|
||||
self.async_embedding_kwargs = kwargs
|
||||
|
@ -60,31 +61,32 @@ class MyCustomHandler(CustomLogger):
|
|||
if kwargs.get("stream") == True:
|
||||
self.stream_collected_response = response_obj
|
||||
self.async_completion_kwargs = kwargs
|
||||
self.user = kwargs.get("user", None)
|
||||
|
||||
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
|
||||
print(f"On Async Failure")
|
||||
self.async_failure = True
|
||||
print("Value of async failure: ", self.async_failure)
|
||||
print("\n kwargs: ", kwargs)
|
||||
if kwargs.get("model") == "text-embedding-ada-002":
|
||||
self.async_failure_embedding = True
|
||||
self.async_embedding_kwargs_fail = kwargs
|
||||
|
||||
self.async_completion_kwargs_fail = kwargs
|
||||
|
||||
async def async_test_logging_fn(kwargs, completion_obj, start_time, end_time):
|
||||
global async_success, complete_streaming_response_in_callback
|
||||
print(f"ON ASYNC LOGGING")
|
||||
async_success = True
|
||||
print("\nKWARGS", kwargs)
|
||||
complete_streaming_response_in_callback = kwargs.get("complete_streaming_response")
|
||||
class TmpFunction:
|
||||
complete_streaming_response_in_callback = ""
|
||||
async_success: bool = False
|
||||
async def async_test_logging_fn(self, kwargs, completion_obj, start_time, end_time):
|
||||
print(f"ON ASYNC LOGGING")
|
||||
self.async_success = True
|
||||
print(f'kwargs.get("complete_streaming_response"): {kwargs.get("complete_streaming_response")}')
|
||||
self.complete_streaming_response_in_callback = kwargs.get("complete_streaming_response")
|
||||
|
||||
|
||||
def test_async_chat_openai_stream():
|
||||
try:
|
||||
global complete_streaming_response_in_callback
|
||||
tmp_function = TmpFunction()
|
||||
# litellm.set_verbose = True
|
||||
litellm.success_callback = [async_test_logging_fn]
|
||||
litellm.success_callback = [tmp_function.async_test_logging_fn]
|
||||
complete_streaming_response = ""
|
||||
async def call_gpt():
|
||||
nonlocal complete_streaming_response
|
||||
|
@ -98,12 +100,16 @@ def test_async_chat_openai_stream():
|
|||
complete_streaming_response += chunk["choices"][0]["delta"]["content"] or ""
|
||||
print(complete_streaming_response)
|
||||
asyncio.run(call_gpt())
|
||||
assert complete_streaming_response_in_callback["choices"][0]["message"]["content"] == complete_streaming_response
|
||||
assert async_success == True
|
||||
complete_streaming_response = complete_streaming_response.strip("'")
|
||||
response1 = tmp_function.complete_streaming_response_in_callback["choices"][0]["message"]["content"]
|
||||
response2 = complete_streaming_response
|
||||
# assert [ord(c) for c in response1] == [ord(c) for c in response2]
|
||||
assert response1 == response2
|
||||
assert tmp_function.async_success == True
|
||||
except Exception as e:
|
||||
print(e)
|
||||
pytest.fail(f"An error occurred - {str(e)}")
|
||||
test_async_chat_openai_stream()
|
||||
# test_async_chat_openai_stream()
|
||||
|
||||
def test_completion_azure_stream_moderation_failure():
|
||||
try:
|
||||
|
@ -205,13 +211,27 @@ def test_azure_completion_stream():
|
|||
assert response_in_success_handler == complete_streaming_response
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
test_azure_completion_stream()
|
||||
|
||||
def test_async_custom_handler():
|
||||
try:
|
||||
customHandler2 = MyCustomHandler()
|
||||
litellm.callbacks = [customHandler2]
|
||||
litellm.set_verbose = True
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_custom_handler_completion():
|
||||
try:
|
||||
customHandler_success = MyCustomHandler()
|
||||
customHandler_failure = MyCustomHandler()
|
||||
# success
|
||||
assert customHandler_success.async_success == False
|
||||
litellm.callbacks = [customHandler_success]
|
||||
response = await litellm.acompletion(
|
||||
model="gpt-3.5-turbo",
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": "hello from litellm test",
|
||||
}]
|
||||
)
|
||||
await asyncio.sleep(1)
|
||||
assert customHandler_success.async_success == True, "async success is not set to True even after success"
|
||||
assert customHandler_success.async_completion_kwargs.get("model") == "gpt-3.5-turbo"
|
||||
# failure
|
||||
litellm.callbacks = [customHandler_failure]
|
||||
messages = [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{
|
||||
|
@ -219,77 +239,101 @@ def test_async_custom_handler():
|
|||
"content": "how do i kill someone",
|
||||
},
|
||||
]
|
||||
async def test_1():
|
||||
try:
|
||||
response = await litellm.acompletion(
|
||||
model="gpt-3.5-turbo",
|
||||
messages=messages,
|
||||
api_key="test",
|
||||
)
|
||||
except:
|
||||
pass
|
||||
|
||||
assert customHandler2.async_failure == False
|
||||
asyncio.run(test_1())
|
||||
assert customHandler2.async_failure == True, "async failure is not set to True even after failure"
|
||||
assert customHandler2.async_completion_kwargs_fail.get("model") == "gpt-3.5-turbo"
|
||||
assert len(str(customHandler2.async_completion_kwargs_fail.get("exception"))) > 10 # exppect APIError("OpenAIException - Error code: 401 - {'error': {'message': 'Incorrect API key provided: test. You can find your API key at https://platform.openai.com/account/api-keys.', 'type': 'invalid_request_error', 'param': None, 'code': 'invalid_api_key'}}"), 'traceback_exception': 'Traceback (most recent call last):\n File "/Users/ishaanjaffer/Github/litellm/litellm/llms/openai.py", line 269, in acompletion\n response = await openai_aclient.chat.completions.create(**data)\n File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/openai/resources/chat/completions.py", line 119
|
||||
print("Passed setting async failure")
|
||||
|
||||
async def test_2():
|
||||
assert customHandler_failure.async_failure == False
|
||||
try:
|
||||
response = await litellm.acompletion(
|
||||
model="gpt-3.5-turbo",
|
||||
messages=[{
|
||||
"role": "user",
|
||||
"content": "hello from litellm test",
|
||||
}]
|
||||
)
|
||||
print("\n response", response)
|
||||
assert customHandler2.async_success == False
|
||||
asyncio.run(test_2())
|
||||
assert customHandler2.async_success == True, "async success is not set to True even after success"
|
||||
assert customHandler2.async_completion_kwargs.get("model") == "gpt-3.5-turbo"
|
||||
model="gpt-3.5-turbo",
|
||||
messages=messages,
|
||||
api_key="my-bad-key",
|
||||
)
|
||||
except:
|
||||
pass
|
||||
assert customHandler_failure.async_failure == True, "async failure is not set to True even after failure"
|
||||
assert customHandler_failure.async_completion_kwargs_fail.get("model") == "gpt-3.5-turbo"
|
||||
assert len(str(customHandler_failure.async_completion_kwargs_fail.get("exception"))) > 10 # expect APIError("OpenAIException - Error code: 401 - {'error': {'message': 'Incorrect API key provided: test. You can find your API key at https://platform.openai.com/account/api-keys.', 'type': 'invalid_request_error', 'param': None, 'code': 'invalid_api_key'}}"), 'traceback_exception': 'Traceback (most recent call last):\n File "/Users/ishaanjaffer/Github/litellm/litellm/llms/openai.py", line 269, in acompletion\n response = await openai_aclient.chat.completions.create(**data)\n File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/openai/resources/chat/completions.py", line 119
|
||||
litellm.callbacks = []
|
||||
print("Passed setting async failure")
|
||||
except Exception as e:
|
||||
pytest.fail(f"An exception occurred - {str(e)}")
|
||||
# asyncio.run(test_async_custom_handler_completion())
|
||||
|
||||
|
||||
async def test_3():
|
||||
response = await litellm.aembedding(
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_custom_handler_embedding():
|
||||
try:
|
||||
customHandler_embedding = MyCustomHandler()
|
||||
litellm.callbacks = [customHandler_embedding]
|
||||
# success
|
||||
assert customHandler_embedding.async_success_embedding == False
|
||||
response = await litellm.aembedding(
|
||||
model="text-embedding-ada-002",
|
||||
input = ["hello world"],
|
||||
)
|
||||
print("\n response", response)
|
||||
assert customHandler2.async_success_embedding == False
|
||||
asyncio.run(test_3())
|
||||
assert customHandler2.async_success_embedding == True, "async_success_embedding is not set to True even after success"
|
||||
assert customHandler2.async_embedding_kwargs.get("model") == "text-embedding-ada-002"
|
||||
assert customHandler2.async_embedding_response["usage"]["prompt_tokens"] ==2
|
||||
await asyncio.sleep(1)
|
||||
assert customHandler_embedding.async_success_embedding == True, "async_success_embedding is not set to True even after success"
|
||||
assert customHandler_embedding.async_embedding_kwargs.get("model") == "text-embedding-ada-002"
|
||||
assert customHandler_embedding.async_embedding_response["usage"]["prompt_tokens"] == 2
|
||||
print("Passed setting async success: Embedding")
|
||||
|
||||
|
||||
print("Testing custom failure callback for embedding")
|
||||
|
||||
async def test_4():
|
||||
try:
|
||||
response = await litellm.aembedding(
|
||||
model="text-embedding-ada-002",
|
||||
input = ["hello world"],
|
||||
api_key="test",
|
||||
)
|
||||
except:
|
||||
pass
|
||||
|
||||
assert customHandler2.async_failure_embedding == False
|
||||
asyncio.run(test_4())
|
||||
assert customHandler2.async_failure_embedding == True, "async failure embedding is not set to True even after failure"
|
||||
assert customHandler2.async_embedding_kwargs_fail.get("model") == "text-embedding-ada-002"
|
||||
assert len(str(customHandler2.async_embedding_kwargs_fail.get("exception"))) > 10 # exppect APIError("OpenAIException - Error code: 401 - {'error': {'message': 'Incorrect API key provided: test. You can find your API key at https://platform.openai.com/account/api-keys.', 'type': 'invalid_request_error', 'param': None, 'code': 'invalid_api_key'}}"), 'traceback_exception': 'Traceback (most recent call last):\n File "/Users/ishaanjaffer/Github/litellm/litellm/llms/openai.py", line 269, in acompletion\n response = await openai_aclient.chat.completions.create(**data)\n File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/openai/resources/chat/completions.py", line 119
|
||||
print("Passed setting async failure")
|
||||
|
||||
# failure
|
||||
assert customHandler_embedding.async_failure_embedding == False
|
||||
try:
|
||||
response = await litellm.aembedding(
|
||||
model="text-embedding-ada-002",
|
||||
input = ["hello world"],
|
||||
api_key="my-bad-key",
|
||||
)
|
||||
except:
|
||||
pass
|
||||
assert customHandler_embedding.async_failure_embedding == True, "async failure embedding is not set to True even after failure"
|
||||
assert customHandler_embedding.async_embedding_kwargs_fail.get("model") == "text-embedding-ada-002"
|
||||
assert len(str(customHandler_embedding.async_embedding_kwargs_fail.get("exception"))) > 10 # expect APIError("OpenAIException - Error code: 401 - {'error': {'message': 'Incorrect API key provided: test. You can find your API key at https://platform.openai.com/account/api-keys.', 'type': 'invalid_request_error', 'param': None, 'code': 'invalid_api_key'}}"), 'traceback_exception': 'Traceback (most recent call last):\n File "/Users/ishaanjaffer/Github/litellm/litellm/llms/openai.py", line 269, in acompletion\n response = await openai_aclient.chat.completions.create(**data)\n File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/openai/resources/chat/completions.py", line 119
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
# test_async_custom_handler()
|
||||
pytest.fail(f"An exception occurred - {str(e)}")
|
||||
# asyncio.run(test_async_custom_handler_embedding())
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_custom_handler_embedding_optional_param():
|
||||
"""
|
||||
Tests that the OpenAI optional params for embedding (user, encoding_format)
|
||||
are logged
|
||||
"""
|
||||
customHandler_optional_params = MyCustomHandler()
|
||||
litellm.callbacks = [customHandler_optional_params]
|
||||
response = await litellm.aembedding(
|
||||
model="azure/azure-embedding-model",
|
||||
input = ["hello world"],
|
||||
user = "John"
|
||||
)
|
||||
await asyncio.sleep(1) # success callback is async
|
||||
assert customHandler_optional_params.user == "John"
|
||||
assert customHandler_optional_params.user == customHandler_optional_params.data_sent_to_api["user"]
|
||||
|
||||
# asyncio.run(test_async_custom_handler_embedding_optional_param())
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_custom_handler_embedding_optional_param_bedrock():
|
||||
"""
|
||||
Tests that the OpenAI optional params for embedding (user, encoding_format)
|
||||
are logged
|
||||
|
||||
but makes sure they are not forwarded to the non-OpenAI/Azure endpoint, where they would raise errors.
|
||||
"""
|
||||
litellm.drop_params = True
|
||||
litellm.set_verbose = True
|
||||
customHandler_optional_params = MyCustomHandler()
|
||||
litellm.callbacks = [customHandler_optional_params]
|
||||
response = await litellm.aembedding(
|
||||
model="bedrock/amazon.titan-embed-text-v1",
|
||||
input = ["hello world"],
|
||||
user = "John"
|
||||
)
|
||||
await asyncio.sleep(1) # success callback is async
|
||||
assert customHandler_optional_params.user == "John"
|
||||
assert "user" not in customHandler_optional_params.data_sent_to_api
|
||||
|
||||
|
||||
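# Both optional-param tests above hinge on the handler capturing the raw provider payload
# inside log_pre_api_call (the data_sent_to_api attribute defined earlier in this file).
# A minimal sketch of that pattern (PayloadCaptureHandler is hypothetical, not the test handler):
import asyncio
import litellm
from litellm.integrations.custom_logger import CustomLogger

class PayloadCaptureHandler(CustomLogger):
    def __init__(self):
        self.data_sent_to_api: dict = {}

    def log_pre_api_call(self, model, messages, kwargs):
        # the exact dict litellm is about to send to the provider
        self.data_sent_to_api = kwargs["additional_args"].get("complete_input_dict", {})

payload_handler = PayloadCaptureHandler()
litellm.callbacks = [payload_handler]

async def _check_user_param():
    await litellm.aembedding(model="azure/azure-embedding-model", input=["hello world"], user="John")
    return "user" in payload_handler.data_sent_to_api  # True for OpenAI/Azure-style endpoints

# asyncio.run(_check_user_param())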
from litellm import Cache
|
||||
def test_redis_cache_completion_stream():
|
||||
from litellm import Cache
|
||||
# Important Test - This tests if we can add to streaming cache, when custom callbacks are set
|
||||
import random
|
||||
try:
|
||||
|
@ -316,13 +360,10 @@ def test_redis_cache_completion_stream():
|
|||
print("\nresponse 2", response_2_content)
|
||||
assert response_1_content == response_2_content, f"Response 1 != Response 2. Same params, Response 1{response_1_content} != Response 2{response_2_content}"
|
||||
litellm.success_callback = []
|
||||
litellm._async_success_callback = []
|
||||
litellm.cache = None
|
||||
except Exception as e:
|
||||
print(e)
|
||||
litellm.success_callback = []
|
||||
raise e
|
||||
"""
|
||||
|
||||
1 & 2 should be exactly the same
|
||||
"""
|
||||
# test_redis_cache_completion_stream()
|
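# For reference, the Redis cache wiring this streaming test depends on reduces to the lines
# below - a sketch only, assuming the same REDIS_* environment variables used by the caching
# tests above and that completion() accepts the same caching=True flag the router calls use:
import os
import litellm
from litellm import Cache, completion

litellm.cache = Cache(
    type="redis",
    host=os.environ["REDIS_HOST"],
    port=os.environ["REDIS_PORT"],
    password=os.environ["REDIS_PASSWORD"],
)

messages = [{"role": "user", "content": "hello from litellm test"}]
response_1 = completion(model="gpt-3.5-turbo", messages=messages, caching=True)
response_2 = completion(model="gpt-3.5-turbo", messages=messages, caching=True)  # may be served from cache

# reset global state so later tests are unaffected, as the test above does
litellm.cache = None
litellm.success_callback = []
litellm._async_success_callback = []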
120
litellm/tests/test_dynamodb_logs.py
Normal file
|
@ -0,0 +1,120 @@
|
|||
import sys
|
||||
import os
|
||||
import io, asyncio
|
||||
# import logging
|
||||
# logging.basicConfig(level=logging.DEBUG)
|
||||
sys.path.insert(0, os.path.abspath('../..'))
|
||||
|
||||
from litellm import completion
|
||||
import litellm
|
||||
litellm.num_retries = 3
|
||||
|
||||
import time, random
|
||||
import pytest
|
||||
|
||||
|
||||
def pre_request():
|
||||
file_name = f"dynamo.log"
|
||||
log_file = open(file_name, "a+")
|
||||
|
||||
# Clear the contents of the file by truncating it
|
||||
log_file.truncate(0)
|
||||
|
||||
# Save the original stdout so that we can restore it later
|
||||
original_stdout = sys.stdout
|
||||
# Redirect stdout to the file
|
||||
sys.stdout = log_file
|
||||
|
||||
return original_stdout, log_file, file_name
|
||||
|
||||
|
||||
import re
|
||||
def verify_log_file(log_file_path):
|
||||
|
||||
with open(log_file_path, 'r') as log_file:
|
||||
log_content = log_file.read()
|
||||
print(f"\nVerifying DynamoDB file = {log_file_path}. File content=", log_content)
|
||||
|
||||
# Define the pattern to search for in the log file
|
||||
pattern = r"Response from DynamoDB:{.*?}"
|
||||
|
||||
# Find all matches in the log content
|
||||
matches = re.findall(pattern, log_content)
|
||||
|
||||
# Print the DynamoDB success log matches
|
||||
print("DynamoDB Success Log Matches:")
|
||||
for match in matches:
|
||||
print(match)
|
||||
|
||||
# Print the total count of lines containing the specified response
|
||||
print(f"Total occurrences of specified response: {len(matches)}")
|
||||
|
||||
# Count the occurrences of successful responses (status code 200 or 201)
|
||||
success_count = sum(1 for match in matches if "'HTTPStatusCode': 200" in match or "'HTTPStatusCode': 201" in match)
|
||||
|
||||
# Print the count of successful responses
|
||||
print(f"Count of successful responses from DynamoDB: {success_count}")
|
||||
assert success_count == 3 # Expect 3 success logs from dynamoDB
|
||||
|
||||
|
||||
def test_dynamo_logging():
|
||||
# all dynamodb requests need to be in one test function
|
||||
# since we are modifying stdout, and pytest runs tests in parallel
|
||||
try:
|
||||
# pre
|
||||
# redirect stdout to log_file
|
||||
|
||||
litellm.success_callback = ["dynamodb"]
|
||||
litellm.dynamodb_table_name = "litellm-logs-1"
|
||||
litellm.set_verbose = True
|
||||
original_stdout, log_file, file_name = pre_request()
|
||||
|
||||
|
||||
print("Testing async dynamoDB logging")
|
||||
async def _test():
|
||||
return await litellm.acompletion(
|
||||
model="gpt-3.5-turbo",
|
||||
messages=[{"role": "user", "content":"This is a test"}],
|
||||
max_tokens=100,
|
||||
temperature=0.7,
|
||||
user = "ishaan-2"
|
||||
)
|
||||
response = asyncio.run(_test())
|
||||
print(f"response: {response}")
|
||||
|
||||
|
||||
# streaming + async
|
||||
async def _test2():
|
||||
response = await litellm.acompletion(
|
||||
model="gpt-3.5-turbo",
|
||||
messages=[{"role": "user", "content":"This is a test"}],
|
||||
max_tokens=10,
|
||||
temperature=0.7,
|
||||
user = "ishaan-2",
|
||||
stream=True
|
||||
)
|
||||
async for chunk in response:
|
||||
pass
|
||||
asyncio.run(_test2())
|
||||
|
||||
# aembedding()
|
||||
async def _test3():
|
||||
return await litellm.aembedding(
|
||||
model="text-embedding-ada-002",
|
||||
input = ["hi"],
|
||||
user = "ishaan-2"
|
||||
)
|
||||
response = asyncio.run(_test3())
|
||||
time.sleep(1)
|
||||
except Exception as e:
|
||||
pytest.fail(f"An exception occurred - {e}")
|
||||
finally:
|
||||
# post, close log file and verify
|
||||
# Reset stdout to the original value
|
||||
sys.stdout = original_stdout
|
||||
# Close the file
|
||||
log_file.close()
|
||||
verify_log_file(file_name)
|
||||
print("Passed! Testing async dynamoDB logging")
|
||||
|
||||
# test_dynamo_logging_async()
|
|
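# Stripped of the stdout-capture scaffolding, the DynamoDB logging this file exercises
# reduces to a success_callback plus a table name. A sketch (the table name is simply the
# one the test above happens to use):
import asyncio
import litellm

litellm.success_callback = ["dynamodb"]
litellm.dynamodb_table_name = "litellm-logs-1"

async def _log_one_call():
    # the response payload for this call gets written to the DynamoDB table above
    return await litellm.acompletion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "This is a test"}],
        max_tokens=100,
    )
# asyncio.run(_log_one_call())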
@ -164,7 +164,7 @@ def test_bedrock_embedding_titan():
|
|||
assert all(isinstance(x, float) for x in response['data'][0]['embedding']), "Expected response to be a list of floats"
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
# test_bedrock_embedding_titan()
|
||||
test_bedrock_embedding_titan()
|
||||
|
||||
def test_bedrock_embedding_cohere():
|
||||
try:
|
||||
|
|
|
@ -21,6 +21,7 @@ from concurrent.futures import ThreadPoolExecutor
|
|||
import pytest
|
||||
litellm.vertex_project = "pathrise-convert-1606954137718"
|
||||
litellm.vertex_location = "us-central1"
|
||||
litellm.num_retries=0
|
||||
|
||||
# litellm.failure_callback = ["sentry"]
|
||||
#### What this tests ####
|
||||
|
@ -38,10 +39,11 @@ models = ["command-nightly"]
|
|||
# Test 1: Context Window Errors
|
||||
@pytest.mark.parametrize("model", models)
|
||||
def test_context_window(model):
|
||||
print("Testing context window error")
|
||||
sample_text = "Say error 50 times" * 1000000
|
||||
messages = [{"content": sample_text, "role": "user"}]
|
||||
try:
|
||||
litellm.set_verbose = False
|
||||
litellm.set_verbose = True
|
||||
response = completion(model=model, messages=messages)
|
||||
print(f"response: {response}")
|
||||
print("FAILED!")
|
||||
|
@ -176,7 +178,7 @@ def test_completion_azure_exception():
|
|||
try:
|
||||
import openai
|
||||
print("azure gpt-3.5 test\n\n")
|
||||
litellm.set_verbose=False
|
||||
litellm.set_verbose=True
|
||||
## Test azure call
|
||||
old_azure_key = os.environ["AZURE_API_KEY"]
|
||||
os.environ["AZURE_API_KEY"] = "good morning"
|
||||
|
@ -189,6 +191,7 @@ def test_completion_azure_exception():
|
|||
}
|
||||
],
|
||||
)
|
||||
os.environ["AZURE_API_KEY"] = old_azure_key
|
||||
print(f"response: {response}")
|
||||
print(response)
|
||||
except openai.AuthenticationError as e:
|
||||
|
@ -196,14 +199,14 @@ def test_completion_azure_exception():
|
|||
print("good job got the correct error for azure when key not set")
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
test_completion_azure_exception()
|
||||
# test_completion_azure_exception()
|
||||
|
||||
async def asynctest_completion_azure_exception():
|
||||
try:
|
||||
import openai
|
||||
import litellm
|
||||
print("azure gpt-3.5 test\n\n")
|
||||
litellm.set_verbose=False
|
||||
litellm.set_verbose=True
|
||||
## Test azure call
|
||||
old_azure_key = os.environ["AZURE_API_KEY"]
|
||||
os.environ["AZURE_API_KEY"] = "good morning"
|
||||
|
@ -226,19 +229,75 @@ async def asynctest_completion_azure_exception():
|
|||
print("Got wrong exception")
|
||||
print("exception", e)
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
# import asyncio
|
||||
# asyncio.run(
|
||||
# asynctest_completion_azure_exception()
|
||||
# )
|
||||
|
||||
|
||||
def asynctest_completion_openai_exception_bad_model():
|
||||
try:
|
||||
import openai
|
||||
import litellm, asyncio
|
||||
print("azure exception bad model\n\n")
|
||||
litellm.set_verbose=True
|
||||
## Test azure call
|
||||
async def test():
|
||||
response = await litellm.acompletion(
|
||||
model="openai/gpt-6",
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "hello"
|
||||
}
|
||||
],
|
||||
)
|
||||
asyncio.run(test())
|
||||
except openai.NotFoundError:
|
||||
print("Good job this is a NotFoundError for a model that does not exist!")
|
||||
print("Passed")
|
||||
except Exception as e:
|
||||
print("Raised wrong type of exception", type(e))
|
||||
assert isinstance(e, openai.BadRequestError)
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
# asynctest_completion_openai_exception_bad_model()
|
||||
|
||||
|
||||
|
||||
def asynctest_completion_azure_exception_bad_model():
|
||||
try:
|
||||
import openai
|
||||
import litellm, asyncio
|
||||
print("azure exception bad model\n\n")
|
||||
litellm.set_verbose=True
|
||||
## Test azure call
|
||||
async def test():
|
||||
response = await litellm.acompletion(
|
||||
model="azure/gpt-12",
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "hello"
|
||||
}
|
||||
],
|
||||
)
|
||||
asyncio.run(test())
|
||||
except openai.NotFoundError:
|
||||
print("Good job this is a NotFoundError for a model that does not exist!")
|
||||
print("Passed")
|
||||
except Exception as e:
|
||||
print("Raised wrong type of exception", type(e))
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
# asynctest_completion_azure_exception_bad_model()
|
||||
|
||||
def test_completion_openai_exception():
|
||||
# test if openai:gpt raises openai.AuthenticationError
|
||||
try:
|
||||
import openai
|
||||
print("openai gpt-3.5 test\n\n")
|
||||
litellm.set_verbose=False
|
||||
litellm.set_verbose=True
|
||||
## Test openai call
|
||||
old_azure_key = os.environ["OPENAI_API_KEY"]
|
||||
os.environ["OPENAI_API_KEY"] = "good morning"
|
||||
|
@ -255,11 +314,38 @@ def test_completion_openai_exception():
|
|||
print(response)
|
||||
except openai.AuthenticationError as e:
|
||||
os.environ["OPENAI_API_KEY"] = old_azure_key
|
||||
print("good job got the correct error for openai when key not set")
|
||||
print("OpenAI: good job got the correct error for openai when key not set")
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
# test_completion_openai_exception()
|
||||
|
||||
def test_completion_mistral_exception():
|
||||
# test if mistral/mistral-tiny raises openai.AuthenticationError
|
||||
try:
|
||||
import openai
|
||||
print("Testing mistral ai exception mapping")
|
||||
litellm.set_verbose=True
|
||||
## Test mistral call
|
||||
old_azure_key = os.environ["MISTRAL_API_KEY"]
|
||||
os.environ["MISTRAL_API_KEY"] = "good morning"
|
||||
response = completion(
|
||||
model="mistral/mistral-tiny",
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "hello"
|
||||
}
|
||||
],
|
||||
)
|
||||
print(f"response: {response}")
|
||||
print(response)
|
||||
except openai.AuthenticationError as e:
|
||||
os.environ["MISTRAL_API_KEY"] = old_azure_key
|
||||
print("good job got the correct error for openai when key not set")
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
# test_completion_mistral_exception()
|
||||
|
||||
|
||||
|
||||
|
||||
|
@ -301,4 +387,4 @@ def test_completion_openai_exception():
|
|||
# counts[result] += 1
|
||||
|
||||
# accuracy_score = counts[True]/(counts[True] + counts[False])
|
||||
# print(f"accuracy_score: {accuracy_score}")
|
||||
# print(f"accuracy_score: {accuracy_score}")
|
|
@ -9,33 +9,107 @@ from litellm import completion
|
|||
import litellm
|
||||
litellm.num_retries = 3
|
||||
litellm.success_callback = ["langfuse"]
|
||||
# litellm.set_verbose = True
|
||||
os.environ["LANGFUSE_DEBUG"] = "True"
|
||||
import time
|
||||
import pytest
|
||||
|
||||
def search_logs(log_file_path):
|
||||
"""
|
||||
Searches the given log file for logs containing the "/api/public" string.
|
||||
|
||||
Parameters:
|
||||
- log_file_path (str): The path to the log file to be searched.
|
||||
|
||||
Returns:
|
||||
- None
|
||||
|
||||
Raises:
|
||||
- Exception: If there are any bad logs found in the log file.
|
||||
"""
|
||||
import re
|
||||
print("\n searching logs")
|
||||
bad_logs = []
|
||||
good_logs = []
|
||||
all_logs = []
|
||||
try:
|
||||
with open(log_file_path, 'r') as log_file:
|
||||
lines = log_file.readlines()
|
||||
print(f"searching logslines: {lines}")
|
||||
for line in lines:
|
||||
all_logs.append(line.strip())
|
||||
if "/api/public" in line:
|
||||
print("Found log with /api/public:")
|
||||
print(line.strip())
|
||||
print("\n\n")
|
||||
match = re.search(r'receive_response_headers.complete return_value=\(b\'HTTP/1.1\', (\d+),', line)
|
||||
if match:
|
||||
status_code = int(match.group(1))
|
||||
if status_code != 200 and status_code != 201:
|
||||
print("got a BAD log")
|
||||
bad_logs.append(line.strip())
|
||||
else:
|
||||
|
||||
good_logs.append(line.strip())
|
||||
print("\nBad Logs")
|
||||
print(bad_logs)
|
||||
if len(bad_logs)>0:
|
||||
raise Exception(f"bad logs, Bad logs = {bad_logs}")
|
||||
|
||||
print("\nGood Logs")
|
||||
print(good_logs)
|
||||
if len(good_logs) <= 0:
|
||||
raise Exception(f"There were no Good Logs from Langfuse. No logs with /api/public status 200. \nAll logs:{all_logs}")
|
||||
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
def pre_langfuse_setup():
|
||||
"""
|
||||
Set up the logging for the 'pre_langfuse_setup' function.
|
||||
"""
|
||||
# sends logs to langfuse.log
|
||||
import logging
|
||||
# Configure the logging to write to a file
|
||||
logging.basicConfig(filename="langfuse.log", level=logging.DEBUG)
|
||||
logger = logging.getLogger()
|
||||
|
||||
# Add a FileHandler to the logger
|
||||
file_handler = logging.FileHandler("langfuse.log", mode='w')
|
||||
file_handler.setLevel(logging.DEBUG)
|
||||
logger.addHandler(file_handler)
|
||||
return
|
||||
|
||||
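# The Langfuse wiring itself is just the one-line callback registration done at the top of
# this file; pre_langfuse_setup only adds file logging so the HTTP traffic can be inspected.
# A minimal sketch, assuming the standard Langfuse SDK environment variables
# (LANGFUSE_PUBLIC_KEY, LANGFUSE_SECRET_KEY) are set:
import litellm
from litellm import completion

litellm.success_callback = ["langfuse"]

response = completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "This is a test"}],
    max_tokens=10,
)
# the generation appears in Langfuse once the success callback has flushed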
@pytest.mark.skip(reason="beta test - checking langfuse output")
|
||||
def test_langfuse_logging_async():
|
||||
try:
|
||||
pre_langfuse_setup()
|
||||
litellm.set_verbose = True
|
||||
async def _test_langfuse():
|
||||
return await litellm.acompletion(
|
||||
model="gpt-3.5-turbo",
|
||||
messages=[{"role": "user", "content":"This is a test"}],
|
||||
max_tokens=1000,
|
||||
max_tokens=100,
|
||||
temperature=0.7,
|
||||
timeout=5,
|
||||
)
|
||||
response = asyncio.run(_test_langfuse())
|
||||
print(f"response: {response}")
|
||||
|
||||
# time.sleep(2)
|
||||
# # check langfuse.log to see if there was a failed response
|
||||
# search_logs("langfuse.log")
|
||||
except litellm.Timeout as e:
|
||||
pass
|
||||
except Exception as e:
|
||||
pytest.fail(f"An exception occurred - {e}")
|
||||
|
||||
# test_langfuse_logging_async()
|
||||
test_langfuse_logging_async()
|
||||
|
||||
@pytest.mark.skip(reason="beta test - checking langfuse output")
|
||||
def test_langfuse_logging():
|
||||
try:
|
||||
# litellm.set_verbose = True
|
||||
pre_langfuse_setup()
|
||||
litellm.set_verbose = True
|
||||
response = completion(model="claude-instant-1.2",
|
||||
messages=[{
|
||||
"role": "user",
|
||||
|
@ -43,17 +117,20 @@ def test_langfuse_logging():
|
|||
}],
|
||||
max_tokens=10,
|
||||
temperature=0.2,
|
||||
metadata={"langfuse/key": "foo"}
|
||||
)
|
||||
print(response)
|
||||
# time.sleep(5)
|
||||
# # check langfuse.log to see if there was a failed response
|
||||
# search_logs("langfuse.log")
|
||||
|
||||
except litellm.Timeout as e:
|
||||
pass
|
||||
except Exception as e:
|
||||
print(e)
|
||||
pytest.fail(f"An exception occurred - {e}")
|
||||
|
||||
test_langfuse_logging()
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="beta test - checking langfuse output")
|
||||
def test_langfuse_logging_stream():
|
||||
try:
|
||||
litellm.set_verbose=True
|
||||
|
@ -77,6 +154,7 @@ def test_langfuse_logging_stream():
|
|||
|
||||
# test_langfuse_logging_stream()
|
||||
|
||||
@pytest.mark.skip(reason="beta test - checking langfuse output")
|
||||
def test_langfuse_logging_custom_generation_name():
|
||||
try:
|
||||
litellm.set_verbose=True
|
||||
|
@ -99,8 +177,8 @@ def test_langfuse_logging_custom_generation_name():
|
|||
pytest.fail(f"An exception occurred - {e}")
|
||||
print(e)
|
||||
|
||||
test_langfuse_logging_custom_generation_name()
|
||||
|
||||
# test_langfuse_logging_custom_generation_name()
|
||||
@pytest.mark.skip(reason="beta test - checking langfuse output")
|
||||
def test_langfuse_logging_function_calling():
|
||||
function1 = [
|
||||
{
|
||||
|
|
|
@ -17,10 +17,10 @@ model_alias_map = {
|
|||
"good-model": "anyscale/meta-llama/Llama-2-7b-chat-hf"
|
||||
}
|
||||
|
||||
litellm.model_alias_map = model_alias_map
|
||||
|
||||
def test_model_alias_map():
|
||||
try:
|
||||
litellm.model_alias_map = model_alias_map
|
||||
response = completion(
|
||||
"good-model",
|
||||
messages=[{"role": "user", "content": "Hey, how's it going?"}],
|
||||
|
|
|
@ -1,100 +1,37 @@
|
|||
##### THESE TESTS CAN ONLY RUN LOCALLY WITH THE OLLAMA SERVER RUNNING ######
|
||||
# import aiohttp
|
||||
# import json
|
||||
# import asyncio
|
||||
# import requests
|
||||
#
|
||||
# async def get_ollama_response_stream(api_base="http://localhost:11434", model="llama2", prompt="Why is the sky blue?"):
|
||||
# session = aiohttp.ClientSession()
|
||||
# url = f'{api_base}/api/generate'
|
||||
# data = {
|
||||
# "model": model,
|
||||
# "prompt": prompt,
|
||||
# }
|
||||
import sys, os
|
||||
import traceback
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# response = ""
|
||||
load_dotenv()
|
||||
import os, io
|
||||
|
||||
# try:
|
||||
# async with session.post(url, json=data) as resp:
|
||||
# async for line in resp.content.iter_any():
|
||||
# if line:
|
||||
# try:
|
||||
# json_chunk = line.decode("utf-8")
|
||||
# chunks = json_chunk.split("\n")
|
||||
# for chunk in chunks:
|
||||
# if chunk.strip() != "":
|
||||
# j = json.loads(chunk)
|
||||
# if "response" in j:
|
||||
# print(j["response"])
|
||||
# yield {
|
||||
# "role": "assistant",
|
||||
# "content": j["response"]
|
||||
# }
|
||||
# # self.responses.append(j["response"])
|
||||
# # yield "blank"
|
||||
# except Exception as e:
|
||||
# print(f"Error decoding JSON: {e}")
|
||||
# finally:
|
||||
# await session.close()
|
||||
|
||||
# async def get_ollama_response_no_stream(api_base="http://localhost:11434", model="llama2", prompt="Why is the sky blue?"):
|
||||
# generator = get_ollama_response_stream(api_base="http://localhost:11434", model="llama2", prompt="Why is the sky blue?")
|
||||
# response = ""
|
||||
# async for elem in generator:
|
||||
# print(elem)
|
||||
# response += elem["content"]
|
||||
# return response
|
||||
|
||||
# #generator = get_ollama_response_stream()
|
||||
|
||||
# result = asyncio.run(get_ollama_response_no_stream())
|
||||
# print(result)
|
||||
|
||||
# # return this generator to the client for streaming requests
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../..")
|
||||
) # Adds the parent directory to the system path
|
||||
import pytest
|
||||
import litellm
|
||||
|
||||
|
||||
# async def get_response():
|
||||
# global generator
|
||||
# async for elem in generator:
|
||||
# print(elem)
|
||||
## for ollama we can't test making the completion call
|
||||
from litellm.utils import get_optional_params, get_llm_provider
|
||||
|
||||
# asyncio.run(get_response())
|
||||
def test_get_ollama_params():
|
||||
try:
|
||||
converted_params = get_optional_params(custom_llm_provider="ollama", model="llama2", max_tokens=20, temperature=0.5, stream=True)
|
||||
print("Converted params", converted_params)
|
||||
assert converted_params == {'num_predict': 20, 'stream': True, 'temperature': 0.5}, f"{converted_params} != {{'num_predict': 20, 'stream': True, 'temperature': 0.5}}"
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
# test_get_ollama_params()
|
||||
|
||||
|
||||
def test_get_ollama_model():
|
||||
try:
|
||||
model, custom_llm_provider, _, _ = get_llm_provider("ollama/code-llama-22")
|
||||
print("Model", "custom_llm_provider", model, custom_llm_provider)
|
||||
assert custom_llm_provider == "ollama", f"{custom_llm_provider} != ollama"
|
||||
assert model == "code-llama-22", f"{model} != code-llama-22"
|
||||
except Exception as e:
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
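# The two checks above only cover provider routing and param mapping. Making a real call
# (as the commented-out tests in this file do) needs a local Ollama server; a sketch,
# assuming the default api_base used throughout this file and a pulled llama2 model:
from litellm import completion

response = completion(
    model="ollama/llama2",
    messages=[{"role": "user", "content": "respond in 20 words. who are you?"}],
    api_base="http://localhost:11434",
)
print(response)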
##### latest implementation of making raw http post requests to local ollama server
|
||||
|
||||
# import requests
|
||||
# import json
|
||||
# def get_ollama_response_stream(api_base="http://localhost:11434", model="llama2", prompt="Why is the sky blue?"):
|
||||
# url = f"{api_base}/api/generate"
|
||||
# data = {
|
||||
# "model": model,
|
||||
# "prompt": prompt,
|
||||
# }
|
||||
# session = requests.Session()
|
||||
|
||||
# with session.post(url, json=data, stream=True) as resp:
|
||||
# for line in resp.iter_lines():
|
||||
# if line:
|
||||
# try:
|
||||
# json_chunk = line.decode("utf-8")
|
||||
# chunks = json_chunk.split("\n")
|
||||
# for chunk in chunks:
|
||||
# if chunk.strip() != "":
|
||||
# j = json.loads(chunk)
|
||||
# if "response" in j:
|
||||
# completion_obj = {
|
||||
# "role": "assistant",
|
||||
# "content": "",
|
||||
# }
|
||||
# completion_obj["content"] = j["response"]
|
||||
# yield {"choices": [{"delta": completion_obj}]}
|
||||
# except Exception as e:
|
||||
# print(f"Error decoding JSON: {e}")
|
||||
# session.close()
|
||||
|
||||
# response = get_ollama_response_stream()
|
||||
|
||||
# for chunk in response:
|
||||
# print(chunk['choices'][0]['delta'])
|
||||
# test_get_ollama_model()
|
|
@ -16,6 +16,19 @@
|
|||
# user_message = "respond in 20 words. who are you?"
|
||||
# messages = [{ "content": user_message,"role": "user"}]
|
||||
|
||||
# async def test_async_ollama_streaming():
|
||||
# try:
|
||||
# litellm.set_verbose = True
|
||||
# response = await litellm.acompletion(model="ollama/mistral-openorca",
|
||||
# messages=[{"role": "user", "content": "Hey, how's it going?"}],
|
||||
# stream=True)
|
||||
# async for chunk in response:
|
||||
# print(chunk)
|
||||
# except Exception as e:
|
||||
# print(e)
|
||||
|
||||
# asyncio.run(test_async_ollama_streaming())
|
||||
|
||||
# def test_completion_ollama():
|
||||
# try:
|
||||
# response = completion(
|
||||
|
@ -29,7 +42,7 @@
|
|||
# except Exception as e:
|
||||
# pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
# test_completion_ollama()
|
||||
# # test_completion_ollama()
|
||||
|
||||
# def test_completion_ollama_with_api_base():
|
||||
# try:
|
||||
|
@ -42,7 +55,7 @@
|
|||
# except Exception as e:
|
||||
# pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
# test_completion_ollama_with_api_base()
|
||||
# # test_completion_ollama_with_api_base()
|
||||
|
||||
|
||||
# def test_completion_ollama_custom_prompt_template():
|
||||
|
@ -72,7 +85,7 @@
|
|||
# traceback.print_exc()
|
||||
# pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
# test_completion_ollama_custom_prompt_template()
|
||||
# # test_completion_ollama_custom_prompt_template()
|
||||
|
||||
# async def test_completion_ollama_async_stream():
|
||||
# user_message = "what is the weather"
|
||||
|
@ -98,8 +111,8 @@
|
|||
# except Exception as e:
|
||||
# pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
# import asyncio
|
||||
# asyncio.run(test_completion_ollama_async_stream())
|
||||
# # import asyncio
|
||||
# # asyncio.run(test_completion_ollama_async_stream())
|
||||
|
||||
|
||||
|
||||
|
@ -154,8 +167,35 @@
|
|||
# pass
|
||||
# pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
# test_completion_expect_error()
|
||||
# # test_completion_expect_error()
|
||||
|
||||
# if __name__ == "__main__":
|
||||
# import asyncio
|
||||
# asyncio.run(main())
|
||||
|
||||
# def test_ollama_llava():
|
||||
# litellm.set_verbose=True
|
||||
# # same params as gpt-4 vision
|
||||
# response = completion(
|
||||
# model = "ollama/llava",
|
||||
# messages=[
|
||||
# {
|
||||
# "role": "user",
|
||||
# "content": [
|
||||
# {
|
||||
# "type": "text",
|
||||
# "text": "What is in this picture"
|
||||
# },
|
||||
# {
|
||||
# "type": "image_url",
|
||||
# "image_url": {
|
||||
# "url": "iVBORw0KGgoAAAANSUhEUgAAAG0AAABmCAYAAADBPx+VAAAACXBIWXMAAAsTAAALEwEAmpwYAAAAAXNSR0IArs4c6QAAAARnQU1BAACxjwv8YQUAAA3VSURBVHgB7Z27r0zdG8fX743i1bi1ikMoFMQloXRpKFFIqI7LH4BEQ+NWIkjQuSWCRIEoULk0gsK1kCBI0IhrQVT7tz/7zZo888yz1r7MnDl7z5xvsjkzs2fP3uu71nNfa7lkAsm7d++Sffv2JbNmzUqcc8m0adOSzZs3Z+/XES4ZckAWJEGWPiCxjsQNLWmQsWjRIpMseaxcuTKpG/7HP27I8P79e7dq1ars/yL4/v27S0ejqwv+cUOGEGGpKHR37tzJCEpHV9tnT58+dXXCJDdECBE2Ojrqjh071hpNECjx4cMHVycM1Uhbv359B2F79+51586daxN/+pyRkRFXKyRDAqxEp4yMlDDzXG1NPnnyJKkThoK0VFd1ELZu3TrzXKxKfW7dMBQ6bcuWLW2v0VlHjx41z717927ba22U9APcw7Nnz1oGEPeL3m3p2mTAYYnFmMOMXybPPXv2bNIPpFZr1NHn4HMw0KRBjg9NuRw95s8PEcz/6DZELQd/09C9QGq5RsmSRybqkwHGjh07OsJSsYYm3ijPpyHzoiacg35MLdDSIS/O1yM778jOTwYUkKNHWUzUWaOsylE00MyI0fcnOwIdjvtNdW/HZwNLGg+sR1kMepSNJXmIwxBZiG8tDTpEZzKg0GItNsosY8USkxDhD0Rinuiko2gfL/RbiD2LZAjU9zKQJj8RDR0vJBR1/Phx9+PHj9Z7REF4nTZkxzX4LCXHrV271qXkBAPGfP/atWvu/PnzHe4C97F48eIsRLZ9+3a3f/9+87dwP1JxaF7/3r17ba+5l4EcaVo0lj3SBq5kGTJSQmLWMjgYNei2GPT1MuMqGTDEFHzeQSP2wi/jGnkmPJ/nhccs44jvDAxpVcxnq0F6eT8h4ni/iIWpR5lPyA6ETkNXoSukvpJAD3AsXLiwpZs49+fPn5ke4j10TqYvegSfn0OnafC+Tv9ooA/JPkgQysqQNBzagXY55nO/oa1F7qvIPWkRL12WRpMWUvpVDYmxAPehxWSe8ZEXL20sadYIozfmNch4QJPAfeJgW3rNsnzphBKNJM2KKODo1rVOMRYik5ETy3ix4qWNI81qAAirizgMIc+yhTytx0JWZuNI03qsrgWlGtwjoS9XwgUhWGyhUaRZZQNNIEwCiXD16tXcAHUs79co0vSD8rrJCIW98pzvxpAWyyo3HYwqS0+H0BjStClcZJT5coMm6D2LOF8TolGJtK9fvyZpyiC5ePFi9nc/oJU4eiEP0jVoAnHa9wyJycITMP78+eMeP37sXrx44d6+fdt6f82aNdkx1pg9e3Zb5W+RSRE+n+VjksQWifvVaTKFhn5O8my63K8Qabdv33b379/PiAP//vuvW7BggZszZ072/+TJk91YgkafPn166zXB1rQHFvouAWHq9z3SEevSUerqCn2/dDCeta2jxYbr69evk4MHDyY7d+7MjhMnTiTPnz9Pfv/+nfQT2ggpO2dMF8cghuoM7Ygj5iWCqRlGFml0QC/ftGmTmzt3rmsaKDsgBSPh0/8yPeLLBihLkOKJc0jp8H8vUzcxIA1k6QJ/c78tWEyj5P3o4u9+jywNPdJi5rAH9x0KHcl4Hg570eQp3+vHXGyrmEeigzQsQsjavXt38ujRo44LQuDDhw+TW7duRS1HGgMxhNXHgflaNTOsHyKvHK5Ijo2jbFjJBQK9YwFd6RVMzfgRBmEfP37suBBm/p49e1qjEP2mwTViNRo0VJWH1deMXcNK08uUjVUu7s/zRaL+oLNxz1bpANco4npUgX4G2eFbpDFyQoQxojBCpEGSytmOH8qrH5Q9vuzD6ofQylkCUmh8DBAr+q8JCyVNtWQIidKQE9wNtLSQnS4jDSsxNHogzFuQBw4cyM61UKVsjfr3ooBkPSqqQHesUPWVtzi9/vQi1T+rJj7WiTz4Pt/l3LxUkr5P2VYZaZ4URpsE+st/dujQoaBBYokbrz/8TJNQYLSonrPS9kUaSkPeZyj1AWSj+d+VBoy1pIWVNed8P0Ll/ee5HdGRhrHhR5GGN0r4LGZBaj8oFDJitBTJzIZgFcmU0Y8ytWMZMzJOaXUSrUs5RxKnrxmbb5YXO9VGUhtpXldhEUogFr3IzIsvlpmdosVcGVGXFWp2oU9kLFL3dEkSz6NHEY1sjSRdIuDFWEhd8KxFqsRi1uM/nz9/zpxnwlESONdg6dKlbsaMGS4EHFHtjFIDHwKOo46l4TxSuxgDzi+rE2jg+BaFruOX4HXa0Nnf1lwAPufZeF8/r6zD97WK2qFnGjBxTw5qNGPxT+5T/r7/7RawFC3j4vTp09koCxkeHjqbHJqArmH5UrFKKksnxrK7FuRIs8STfBZv+luugXZ2pR/pP9Ois4z+TiMzUUkUjD0iEi1fzX8GmXyuxUBRcaUfykV0YZnlJGKQpOiGB76x5GeWkWWJc3mOrK6S7xdND+W5N6XyaRgtWJFe13GkaZnKOsYqGdOVVVbGupsyA/l7emTLHi7vwTdirNEt0qxnzAvBFcnQF16xh/TMpUuXHDowhlA9vQVraQhkudRdzOnK+04ZSP3DUhVSP61YsaLtd/ks7ZgtPcXqPqEafHkdqa84X6aCeL7YWlv6edGFHb+ZFICPlljHhg0bKuk0CSvVznWsotRu433alNdFrqG45ejoaPCaUkWERpLXjzFL2Rpllp7PJU2a/v7Ab8N05/9t27Z16KUqoFGsxnI9EosS2niSYg9SpU6B4JgTrvVW1flt1sT+0ADIJU2maXzcUTraGCRaL1Wp9rUMk16PMom8QhruxzvZIegJjFU7LLCePfS8uaQdPny4jTTL0dbee5mYokQsXTIWNY46kuMbnt8Kmec+LGWtOVIl9cT1rCB0V8WqkjAsRwta93TbwNYoGKsUSChN44lgBNCoHLHzquYKrU6qZ8lolCIN0Rh6cP0Q3U6I6IXILYOQI513hJaSKAorFpuHXJNfVlpRtmYBk1Su1obZr5dnKAO+L10Hrj3WZW+E3qh6IszE37F6EB+68mGpvKm4eb9bFrlzrok7fvr0Kfv727dvWRmdVTJHw0qiiCUSZ6wCK+7XL/AcsgNyL74DQQ730sv78Su7+t/A36MdY0sW5o40ahslXr58aZ5HtZB8GH64m9EmMZ7FpYw4T6QnrZfgenrhFxaSiSGXtPnz57e9TkNZLvTjeqhr734CNtrK41L40sUQckmj1lGKQ0rC37x544r8eNXRpnVE3ZZY7zXo8NomiO0ZUCj2uHz58rbXoZ6gc0uA+F6ZeKS/jhRDUq8MKrTho9fEkihMmhxtBI1DxKFY9XLpVcSkfoi8JGnToZO5sU5aiDQIW716ddt7ZLYtMQlhECdBGXZZMWldY5BHm5xgAroWj4C0hbYkSc/jBmggIrXJWlZM6pSETsEPGqZOndr2uuuR5rF16
9a2HoHPdurUKZM4CO1WTPqaDaAd+GFGKdIQkxAn9RuEWcTRyN2KSUgiSgF5aWzPTeA/lN5rZubMmR2bE4SIC4nJoltgAV/dVefZm72AtctUCJU2CMJ327hxY9t7EHbkyJFseq+EJSY16RPo3Dkq1kkr7+q0bNmyDuLQcZBEPYmHVdOBiJyIlrRDq41YPWfXOxUysi5fvtyaj+2BpcnsUV/oSoEMOk2CQGlr4ckhBwaetBhjCwH0ZHtJROPJkyc7UjcYLDjmrH7ADTEBXFfOYmB0k9oYBOjJ8b4aOYSe7QkKcYhFlq3QYLQhSidNmtS2RATwy8YOM3EQJsUjKiaWZ+vZToUQgzhkHXudb/PW5YMHD9yZM2faPsMwoc7RciYJXbGuBqJ1UIGKKLv915jsvgtJxCZDubdXr165mzdvtr1Hz5LONA8jrUwKPqsmVesKa49S3Q4WxmRPUEYdTjgiUcfUwLx589ySJUva3oMkP6IYddq6HMS4o55xBJBUeRjzfa4Zdeg56QZ43LhxoyPo7Lf1kNt7oO8wWAbNwaYjIv5lhyS7kRf96dvm5Jah8vfvX3flyhX35cuX6HfzFHOToS1H4BenCaHvO8pr8iDuwoUL7tevX+b5ZdbBair0xkFIlFDlW4ZknEClsp/TzXyAKVOmmHWFVSbDNw1l1+4f90U6IY/q4V27dpnE9bJ+v87QEydjqx/UamVVPRG+mwkNTYN+9tjkwzEx+atCm/X9WvWtDtAb68Wy9LXa1UmvCDDIpPkyOQ5ZwSzJ4jMrvFcr0rSjOUh+GcT4LSg5ugkW1Io0/SCDQBojh0hPlaJdah+tkVYrnTZowP8iq1F1TgMBBauufyB33x1v+NWFYmT5KmppgHC+NkAgbmRkpD3yn9QIseXymoTQFGQmIOKTxiZIWpvAatenVqRVXf2nTrAWMsPnKrMZHz6bJq5jvce6QK8J1cQNgKxlJapMPdZSR64/UivS9NztpkVEdKcrs5alhhWP9NeqlfWopzhZScI6QxseegZRGeg5a8C3Re1Mfl1ScP36ddcUaMuv24iOJtz7sbUjTS4qBvKmstYJoUauiuD3k5qhyr7QdUHMeCgLa1Ear9NquemdXgmum4fvJ6w1lqsuDhNrg1qSpleJK7K3TF0Q2jSd94uSZ60kK1e3qyVpQK6PVWXp2/FC3mp6jBhKKOiY2h3gtUV64TWM6wDETRPLDfSakXmH3w8g9Jlug8ZtTt4kVF0kLUYYmCCtD/DrQ5YhMGbA9L3ucdjh0y8kOHW5gU/VEEmJTcL4Pz/f7mgoAbYkAAAAAElFTkSuQmCC"
|
||||
# }
|
||||
# }
|
||||
# ]
|
||||
# }
|
||||
# ],
|
||||
# )
|
||||
# print("Response from ollama/llava")
|
||||
# print(response)
|
||||
# test_ollama_llava()
|
||||
|
||||
|
||||
# # PROCESSED CHUNK PRE CHUNK CREATOR
|
||||
|
|
27
litellm/tests/test_optional_params.py
Normal file
|
@ -0,0 +1,27 @@
|
|||
#### What this tests ####
|
||||
# This tests if get_optional_params works as expected
|
||||
import sys, os, time, inspect, asyncio, traceback
|
||||
import pytest
|
||||
sys.path.insert(0, os.path.abspath('../..'))
|
||||
import litellm
|
||||
from litellm.utils import get_optional_params_embeddings
|
||||
## get_optional_params_embeddings
|
||||
### Models: OpenAI, Azure, Bedrock
|
||||
### Scenarios: w/ optional params + litellm.drop_params = True
|
||||
|
||||
def test_bedrock_optional_params_embeddings():
|
||||
litellm.drop_params = True
|
||||
optional_params = get_optional_params_embeddings(user="John", encoding_format=None, custom_llm_provider="bedrock")
|
||||
assert len(optional_params) == 0
|
||||
|
||||
def test_openai_optional_params_embeddings():
|
||||
litellm.drop_params = True
|
||||
optional_params = get_optional_params_embeddings(user="John", encoding_format=None, custom_llm_provider="openai")
|
||||
assert len(optional_params) == 1
|
||||
assert optional_params["user"] == "John"
|
||||
|
||||
def test_azure_optional_params_embeddings():
|
||||
litellm.drop_params = True
|
||||
optional_params = get_optional_params_embeddings(user="John", encoding_format=None, custom_llm_provider="azure")
|
||||
assert len(optional_params) == 1
|
||||
assert optional_params["user"] == "John"
|
|
@ -19,21 +19,23 @@ from litellm import RateLimitError
|
|||
from fastapi.testclient import TestClient
|
||||
from fastapi import FastAPI
|
||||
from litellm.proxy.proxy_server import router, save_worker_config, initialize # Replace with the actual module where your FastAPI router is defined
|
||||
filepath = os.path.dirname(os.path.abspath(__file__))
|
||||
config_fp = f"{filepath}/test_configs/test_config_custom_auth.yaml"
|
||||
save_worker_config(config=config_fp, model=None, alias=None, api_base=None, api_version=None, debug=False, temperature=None, max_tokens=None, request_timeout=600, max_budget=None, telemetry=False, drop_params=True, add_function_to_prompt=False, headers=None, save=False, use_queue=False)
|
||||
app = FastAPI()
|
||||
app.include_router(router) # Include your router in the test app
|
||||
@app.on_event("startup")
|
||||
async def wrapper_startup_event():
|
||||
initialize(config=config_fp, model=None, alias=None, api_base=None, api_version=None, debug=False, temperature=None, max_tokens=None, request_timeout=600, max_budget=None, telemetry=False, drop_params=True, add_function_to_prompt=False, headers=None, save=False, use_queue=False)
|
||||
|
||||
|
||||
# Here you create a fixture that will be used by your tests
|
||||
# Make sure the fixture returns TestClient(app)
|
||||
@pytest.fixture(autouse=True)
|
||||
@pytest.fixture(scope="function")
|
||||
def client():
|
||||
with TestClient(app) as client:
|
||||
yield client
|
||||
from litellm.proxy.proxy_server import cleanup_router_config_variables
|
||||
cleanup_router_config_variables()
|
||||
filepath = os.path.dirname(os.path.abspath(__file__))
|
||||
config_fp = f"{filepath}/test_configs/test_config_custom_auth.yaml"
|
||||
# initialize can get run in parallel; it sets specific variables for the FastAPI app, and since it gets run in parallel, different tests can end up using the wrong variables
|
||||
app = FastAPI()
|
||||
initialize(config=config_fp)
|
||||
|
||||
app.include_router(router) # Include your router in the test app
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
def test_custom_auth(client):
|
||||
try:
|
||||
|
|
|
@@ -3,7 +3,7 @@ import traceback
from dotenv import load_dotenv

load_dotenv()
import os, io
import os, io, asyncio

# this file is to test litellm/proxy


@@ -21,21 +21,24 @@ from fastapi.testclient import TestClient
from fastapi import FastAPI
from litellm.proxy.proxy_server import router, save_worker_config, initialize  # Replace with the actual module where your FastAPI router is defined
filepath = os.path.dirname(os.path.abspath(__file__))
config_fp = f"{filepath}/test_configs/test_custom_logger.yaml"
python_file_path = f"{filepath}/test_configs/custom_callbacks.py"
save_worker_config(config=config_fp, model=None, alias=None, api_base=None, api_version=None, debug=False, temperature=None, max_tokens=None, request_timeout=600, max_budget=None, telemetry=False, drop_params=True, add_function_to_prompt=False, headers=None, save=False, use_queue=False)
app = FastAPI()
app.include_router(router)  # Include your router in the test app
@app.on_event("startup")
async def wrapper_startup_event():
    initialize(config=config_fp, model=None, alias=None, api_base=None, api_version=None, debug=True, temperature=None, max_tokens=None, request_timeout=600, max_budget=None, telemetry=False, drop_params=True, add_function_to_prompt=False, headers=None, save=False, use_queue=False)

# Here you create a fixture that will be used by your tests
# Make sure the fixture returns TestClient(app)
@pytest.fixture(autouse=True)
# @app.on_event("startup")
# async def wrapper_startup_event():
# initialize(config=config_fp)

# Use the app fixture in your client fixture

@pytest.fixture
def client():
    with TestClient(app) as client:
        yield client
    filepath = os.path.dirname(os.path.abspath(__file__))
    config_fp = f"{filepath}/test_configs/test_custom_logger.yaml"
    initialize(config=config_fp)
    app = FastAPI()
    app.include_router(router)  # Include your router in the test app
    return TestClient(app)


# Your bearer token
token = os.getenv("PROXY_MASTER_KEY")

@@ -45,15 +48,76 @@ headers = {
}


def test_chat_completion(client):
    print("Testing proxy custom logger")

def test_embedding(client):
    try:
        # Your test data
        litellm.set_verbose=False
        from litellm.proxy.utils import get_instance_fn
        my_custom_logger = get_instance_fn(
            value = "custom_callbacks.my_custom_logger",
            config_file_path=python_file_path
        )
        print("id of initialized custom logger", id(my_custom_logger))
        litellm.callbacks = [my_custom_logger]
        # Your test data
        print("initialized proxy")
        # import the initialized custom logger
        print(litellm.callbacks)

        assert len(litellm.callbacks) == 1 # assert litellm is initialized with 1 callback
        my_custom_logger = litellm.callbacks[0]
        # assert len(litellm.callbacks) == 1 # assert litellm is initialized with 1 callback
        print("my_custom_logger", my_custom_logger)
        assert my_custom_logger.async_success_embedding == False

        test_data = {
            "model": "azure-embedding-model",
            "input": ["hello"]
        }
        response = client.post("/embeddings", json=test_data, headers=headers)
        print("made request", response.status_code, response.text)
        print("vars my custom logger /embeddings", vars(my_custom_logger), "id", id(my_custom_logger))
        assert my_custom_logger.async_success_embedding == True # checks if the status of async_success is True, only the async_log_success_event can set this to true
        assert my_custom_logger.async_embedding_kwargs["model"] == "azure-embedding-model" # checks if kwargs passed to async_log_success_event are correct
        kwargs = my_custom_logger.async_embedding_kwargs
        litellm_params = kwargs.get("litellm_params")
        metadata = litellm_params.get("metadata", None)
        print("\n\n Metadata in custom logger kwargs", litellm_params.get("metadata"))
        assert metadata is not None
        assert "user_api_key" in metadata
        assert "headers" in metadata
        proxy_server_request = litellm_params.get("proxy_server_request")
        model_info = litellm_params.get("model_info")
        assert proxy_server_request == {'url': 'http://testserver/embeddings', 'method': 'POST', 'headers': {'host': 'testserver', 'accept': '*/*', 'accept-encoding': 'gzip, deflate', 'connection': 'keep-alive', 'user-agent': 'testclient', 'authorization': 'Bearer sk-1234', 'content-length': '54', 'content-type': 'application/json'}, 'body': {'model': 'azure-embedding-model', 'input': ['hello']}}
        assert model_info == {'input_cost_per_token': 0.002, 'mode': 'embedding', 'id': 'hello'}
        result = response.json()
        print(f"Received response: {result}")
        print("Passed Embedding custom logger on proxy!")
    except Exception as e:
        pytest.fail(f"LiteLLM Proxy test failed. Exception {str(e)}")


def test_chat_completion(client):
    try:
        # Your test data

        print("initialized proxy")
        litellm.set_verbose=False
        from litellm.proxy.utils import get_instance_fn
        my_custom_logger = get_instance_fn(
            value = "custom_callbacks.my_custom_logger",
            config_file_path=python_file_path
        )

        print("id of initialized custom logger", id(my_custom_logger))

        litellm.callbacks = [my_custom_logger]
        # import the initialized custom logger
        print(litellm.callbacks)

        # assert len(litellm.callbacks) == 1 # assert litellm is initialized with 1 callback

        print("LiteLLM Callbacks", litellm.callbacks)
        print("my_custom_logger", my_custom_logger)
        assert my_custom_logger.async_success == False

        test_data = {

@@ -61,7 +125,7 @@ def test_chat_completion(client):
            "messages": [
                {
                    "role": "user",
                    "content": "hi"
                    "content": "write a litellm poem"
                },
            ],
            "max_tokens": 10,

@@ -70,33 +134,53 @@ def test_chat_completion(client):

        response = client.post("/chat/completions", json=test_data, headers=headers)
        print("made request", response.status_code, response.text)
        print("LiteLLM Callbacks", litellm.callbacks)
        asyncio.sleep(1) # sleep while waiting for callback to run

        print("my_custom_logger in /chat/completions", my_custom_logger, "id", id(my_custom_logger))
        print("vars my custom logger, ", vars(my_custom_logger))
        assert my_custom_logger.async_success == True # checks if the status of async_success is True, only the async_log_success_event can set this to true
        assert my_custom_logger.async_completion_kwargs["model"] == "chatgpt-v-2" # checks if kwargs passed to async_log_success_event are correct
        print("\n\n Custom Logger Async Completion args", my_custom_logger.async_completion_kwargs)

        litellm_params = my_custom_logger.async_completion_kwargs.get("litellm_params")
        metadata = litellm_params.get("metadata", None)
        print("\n\n Metadata in custom logger kwargs", litellm_params.get("metadata"))
        assert metadata is not None
        assert "user_api_key" in metadata
        assert "headers" in metadata
        config_model_info = litellm_params.get("model_info")
        proxy_server_request_object = litellm_params.get("proxy_server_request")

        assert config_model_info == {'id': 'gm', 'input_cost_per_token': 0.0002, 'mode': 'chat'}
        assert proxy_server_request_object == {'url': 'http://testserver/chat/completions', 'method': 'POST', 'headers': {'host': 'testserver', 'accept': '*/*', 'accept-encoding': 'gzip, deflate', 'connection': 'keep-alive', 'user-agent': 'testclient', 'authorization': 'Bearer sk-1234', 'content-length': '105', 'content-type': 'application/json'}, 'body': {'model': 'Azure OpenAI GPT-4 Canada', 'messages': [{'role': 'user', 'content': 'hi'}], 'max_tokens': 10}}
        assert proxy_server_request_object == {'url': 'http://testserver/chat/completions', 'method': 'POST', 'headers': {'host': 'testserver', 'accept': '*/*', 'accept-encoding': 'gzip, deflate', 'connection': 'keep-alive', 'user-agent': 'testclient', 'authorization': 'Bearer sk-1234', 'content-length': '123', 'content-type': 'application/json'}, 'body': {'model': 'Azure OpenAI GPT-4 Canada', 'messages': [{'role': 'user', 'content': 'write a litellm poem'}], 'max_tokens': 10}}
        result = response.json()
        print(f"Received response: {result}")
        print("\nPassed /chat/completions with Custom Logger!")
    except Exception as e:
        pytest.fail("LiteLLM Proxy test failed. Exception", e)
        pytest.fail(f"LiteLLM Proxy test failed. Exception {str(e)}")


def test_chat_completion_stream(client):
    try:
        # Your test data
        litellm.set_verbose=False
        from litellm.proxy.utils import get_instance_fn
        my_custom_logger = get_instance_fn(
            value = "custom_callbacks.my_custom_logger",
            config_file_path=python_file_path
        )

        print("id of initialized custom logger", id(my_custom_logger))

        litellm.callbacks = [my_custom_logger]
        import json
        print("initialized proxy")
        # import the initialized custom logger
        print(litellm.callbacks)


        assert len(litellm.callbacks) == 1 # assert litellm is initialized with 1 callback
        my_custom_logger = litellm.callbacks[0]
        print("LiteLLM Callbacks", litellm.callbacks)
        print("my_custom_logger", my_custom_logger)

        assert my_custom_logger.streaming_response_obj == None # no streaming response obj is set pre call


@@ -148,37 +232,5 @@ def test_chat_completion_stream(client):
        assert complete_response == streamed_response["choices"][0]["message"]["content"]

    except Exception as e:
        pytest.fail("LiteLLM Proxy test failed. Exception", e)
        pytest.fail(f"LiteLLM Proxy test failed. Exception {str(e)}")



def test_embedding(client):
    try:
        # Your test data
        print("initialized proxy")
        # import the initialized custom logger
        print(litellm.callbacks)

        assert len(litellm.callbacks) == 1 # assert litellm is initialized with 1 callback
        my_custom_logger = litellm.callbacks[0]
        assert my_custom_logger.async_success_embedding == False

        test_data = {
            "model": "azure-embedding-model",
            "input": ["hello"]
        }
        response = client.post("/embeddings", json=test_data, headers=headers)
        print("made request", response.status_code, response.text)
        assert my_custom_logger.async_success_embedding == True # checks if the status of async_success is True, only the async_log_success_event can set this to true
        assert my_custom_logger.async_embedding_kwargs["model"] == "azure-embedding-model" # checks if kwargs passed to async_log_success_event are correct

        kwargs = my_custom_logger.async_embedding_kwargs
        litellm_params = kwargs.get("litellm_params")
        proxy_server_request = litellm_params.get("proxy_server_request")
        model_info = litellm_params.get("model_info")
        assert proxy_server_request == {'url': 'http://testserver/embeddings', 'method': 'POST', 'headers': {'host': 'testserver', 'accept': '*/*', 'accept-encoding': 'gzip, deflate', 'connection': 'keep-alive', 'user-agent': 'testclient', 'authorization': 'Bearer sk-1234', 'content-length': '54', 'content-type': 'application/json'}, 'body': {'model': 'azure-embedding-model', 'input': ['hello']}}
        assert model_info == {'input_cost_per_token': 0.002, 'mode': 'embedding', 'id': 'hello'}
        result = response.json()
        print(f"Received response: {result}")
    except Exception as e:
        pytest.fail("LiteLLM Proxy test failed. Exception", e)
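For orientation, here is a rough sketch of what the custom_callbacks.my_custom_logger object loaded via get_instance_fn above might look like; the class body is an assumption reconstructed from the attributes these tests assert on (async_success, async_success_embedding, the saved kwargs, streaming_response_obj), not the actual test_configs/custom_callbacks.py.

from litellm.integrations.custom_logger import CustomLogger

class MyCustomLogger(CustomLogger):
    def __init__(self):
        # flags and captured kwargs that the proxy tests read back after each request
        self.async_success = False
        self.async_success_embedding = False
        self.async_completion_kwargs = None
        self.async_embedding_kwargs = None
        self.streaming_response_obj = None

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        # using call_type to tell embedding calls from chat calls is an assumption
        if kwargs.get("call_type") in ("embedding", "aembedding"):
            self.async_success_embedding = True
            self.async_embedding_kwargs = kwargs
        else:
            self.async_success = True
            self.async_completion_kwargs = kwargs

my_custom_logger = MyCustomLogger()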
177
litellm/tests/test_proxy_exception_mapping.py
Normal file
@@ -0,0 +1,177 @@
# test that the proxy actually does exception mapping to the OpenAI format

import sys, os
from dotenv import load_dotenv

load_dotenv()
import os, io, asyncio
sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
import pytest
import litellm, openai
from fastapi.testclient import TestClient
from fastapi import FastAPI
from litellm.proxy.proxy_server import router, save_worker_config, initialize  # Replace with the actual module where your FastAPI router is defined

@pytest.fixture
def client():
    filepath = os.path.dirname(os.path.abspath(__file__))
    config_fp = f"{filepath}/test_configs/test_bad_config.yaml"
    initialize(config=config_fp)
    app = FastAPI()
    app.include_router(router)  # Include your router in the test app
    return TestClient(app)

# raise openai.AuthenticationError
def test_chat_completion_exception(client):
    try:
        # Your test data
        test_data = {
            "model": "gpt-3.5-turbo",
            "messages": [
                {
                    "role": "user",
                    "content": "hi"
                },
            ],
            "max_tokens": 10,
        }

        response = client.post("/chat/completions", json=test_data)

        # make an openai client to call _make_status_error_from_response
        openai_client = openai.OpenAI(api_key="anything")
        openai_exception = openai_client._make_status_error_from_response(response=response)
        assert isinstance(openai_exception, openai.AuthenticationError)

    except Exception as e:
        pytest.fail(f"LiteLLM Proxy test failed. Exception {str(e)}")

# raise openai.AuthenticationError
def test_chat_completion_exception_azure(client):
    try:
        # Your test data
        test_data = {
            "model": "azure-gpt-3.5-turbo",
            "messages": [
                {
                    "role": "user",
                    "content": "hi"
                },
            ],
            "max_tokens": 10,
        }

        response = client.post("/chat/completions", json=test_data)

        # make an openai client to call _make_status_error_from_response
        openai_client = openai.OpenAI(api_key="anything")
        openai_exception = openai_client._make_status_error_from_response(response=response)
        assert isinstance(openai_exception, openai.AuthenticationError)

    except Exception as e:
        pytest.fail(f"LiteLLM Proxy test failed. Exception {str(e)}")


# raise openai.AuthenticationError
def test_embedding_auth_exception_azure(client):
    try:
        # Your test data
        test_data = {
            "model": "azure-embedding",
            "input": ["hi"]
        }

        response = client.post("/embeddings", json=test_data)
        print("Response from proxy=", response)

        # make an openai client to call _make_status_error_from_response
        openai_client = openai.OpenAI(api_key="anything")
        openai_exception = openai_client._make_status_error_from_response(response=response)
        print("Exception raised=", openai_exception)
        assert isinstance(openai_exception, openai.AuthenticationError)

    except Exception as e:
        pytest.fail(f"LiteLLM Proxy test failed. Exception {str(e)}")



# raise openai.BadRequestError
# chat/completions openai
def test_exception_openai_bad_model(client):
    try:
        # Your test data
        test_data = {
            "model": "azure/GPT-12",
            "messages": [
                {
                    "role": "user",
                    "content": "hi"
                },
            ],
            "max_tokens": 10,
        }

        response = client.post("/chat/completions", json=test_data)

        # make an openai client to call _make_status_error_from_response
        openai_client = openai.OpenAI(api_key="anything")
        openai_exception = openai_client._make_status_error_from_response(response=response)
        print("Type of exception=", type(openai_exception))
        assert isinstance(openai_exception, openai.NotFoundError)

    except Exception as e:
        pytest.fail(f"LiteLLM Proxy test failed. Exception {str(e)}")

# chat/completions any model
def test_chat_completion_exception_any_model(client):
    try:
        # Your test data
        test_data = {
            "model": "Lite-GPT-12",
            "messages": [
                {
                    "role": "user",
                    "content": "hi"
                },
            ],
            "max_tokens": 10,
        }

        response = client.post("/chat/completions", json=test_data)

        # make an openai client to call _make_status_error_from_response
        openai_client = openai.OpenAI(api_key="anything")
        openai_exception = openai_client._make_status_error_from_response(response=response)
        print("Exception raised=", openai_exception)
        assert isinstance(openai_exception, openai.NotFoundError)

    except Exception as e:
        pytest.fail(f"LiteLLM Proxy test failed. Exception {str(e)}")



# embeddings any model
def test_embedding_exception_any_model(client):
    try:
        # Your test data
        test_data = {
            "model": "Lite-GPT-12",
            "input": ["hi"]
        }

        response = client.post("/embeddings", json=test_data)
        print("Response from proxy=", response)

        # make an openai client to call _make_status_error_from_response
        openai_client = openai.OpenAI(api_key="anything")
        openai_exception = openai_client._make_status_error_from_response(response=response)
        print("Exception raised=", openai_exception)
        assert isinstance(openai_exception, openai.NotFoundError)

    except Exception as e:
        pytest.fail(f"LiteLLM Proxy test failed. Exception {str(e)}")
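As a usage note, the same exception mapping is what a plain OpenAI SDK client sees when it talks to a running proxy directly; a minimal sketch, assuming a local proxy URL and an invalid key (both hypothetical):

import openai

client = openai.OpenAI(api_key="sk-bad-key", base_url="http://0.0.0.0:8000")  # hypothetical local proxy URL
try:
    client.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "hi"}],
    )
except openai.AuthenticationError as e:
    # the proxy returns an OpenAI-formatted error body, so the SDK raises the matching exception class
    print("mapped exception:", type(e).__name__)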
@@ -24,30 +24,29 @@ logging.basicConfig(
from fastapi.testclient import TestClient
from fastapi import FastAPI
from litellm.proxy.proxy_server import router, save_worker_config, initialize  # Replace with the actual module where your FastAPI router is defined
filepath = os.path.dirname(os.path.abspath(__file__))
config_fp = f"{filepath}/test_configs/test_config_no_auth.yaml"
save_worker_config(config=config_fp, model=None, alias=None, api_base=None, api_version=None, debug=False, temperature=None, max_tokens=None, request_timeout=600, max_budget=None, telemetry=False, drop_params=True, add_function_to_prompt=False, headers=None, save=False, use_queue=False)
app = FastAPI()
app.include_router(router)  # Include your router in the test app
@app.on_event("startup")
async def wrapper_startup_event():
    initialize(config=config_fp)

# Your bearer token
token = os.getenv("PROXY_MASTER_KEY")
token = ""

headers = {
    "Authorization": f"Bearer {token}"
}

# Here you create a fixture that will be used by your tests
# Make sure the fixture returns TestClient(app)
@pytest.fixture(autouse=True)
def client():
    with TestClient(app) as client:
        yield client
@pytest.fixture(scope="function")
def client_no_auth():
    # Assuming litellm.proxy.proxy_server is an object
    from litellm.proxy.proxy_server import cleanup_router_config_variables
    cleanup_router_config_variables()
    filepath = os.path.dirname(os.path.abspath(__file__))
    config_fp = f"{filepath}/test_configs/test_config_no_auth.yaml"
    # initialize can get run in parallel; it sets specific variables for the FastAPI app, and since it gets run in parallel, different tests can end up using the wrong variables
    initialize(config=config_fp)
    app = FastAPI()
    app.include_router(router)  # Include your router in the test app

def test_chat_completion(client):
    return TestClient(app)

def test_chat_completion(client_no_auth):
    global headers
    try:
        # Your test data

@@ -62,8 +61,8 @@ def test_chat_completion(client):
            "max_tokens": 10,
        }

        print("testing proxy server")
        response = client.post("/v1/chat/completions", json=test_data, headers=headers)
        print("testing proxy server with chat completions")
        response = client_no_auth.post("/v1/chat/completions", json=test_data)
        print(f"response - {response.text}")
        assert response.status_code == 200
        result = response.json()

@@ -73,7 +72,8 @@ def test_chat_completion(client):

# Run the test

def test_chat_completion_azure(client):
def test_chat_completion_azure(client_no_auth):

    global headers
    try:
        # Your test data

@@ -88,8 +88,8 @@ def test_chat_completion_azure(client):
            "max_tokens": 10,
        }

        print("testing proxy server with Azure Request")
        response = client.post("/v1/chat/completions", json=test_data, headers=headers)
        print("testing proxy server with Azure Request /chat/completions")
        response = client_no_auth.post("/v1/chat/completions", json=test_data)

        assert response.status_code == 200
        result = response.json()

@@ -102,15 +102,55 @@ def test_chat_completion_azure(client):
# test_chat_completion_azure()


def test_embedding(client):
def test_embedding(client_no_auth):
    global headers
    from litellm.proxy.proxy_server import user_custom_auth

    try:
        test_data = {
            "model": "azure/azure-embedding-model",
            "input": ["good morning from litellm"],
        }
        print("testing proxy server with OpenAI embedding")
        response = client.post("/v1/embeddings", json=test_data, headers=headers)

        response = client_no_auth.post("/v1/embeddings", json=test_data)

        assert response.status_code == 200
        result = response.json()
        print(len(result["data"][0]["embedding"]))
        assert len(result["data"][0]["embedding"]) > 10 # this usually has len==1536
    except Exception as e:
        pytest.fail(f"LiteLLM Proxy test failed. Exception - {str(e)}")

def test_bedrock_embedding(client_no_auth):
    global headers
    from litellm.proxy.proxy_server import user_custom_auth

    try:
        test_data = {
            "model": "amazon-embeddings",
            "input": ["good morning from litellm"],
        }

        response = client_no_auth.post("/v1/embeddings", json=test_data)

        assert response.status_code == 200
        result = response.json()
        print(len(result["data"][0]["embedding"]))
        assert len(result["data"][0]["embedding"]) > 10 # this usually has len==1536
    except Exception as e:
        pytest.fail(f"LiteLLM Proxy test failed. Exception - {str(e)}")

def test_sagemaker_embedding(client_no_auth):
    global headers
    from litellm.proxy.proxy_server import user_custom_auth

    try:
        test_data = {
            "model": "GPT-J 6B - Sagemaker Text Embedding (Internal)",
            "input": ["good morning from litellm"],
        }

        response = client_no_auth.post("/v1/embeddings", json=test_data)

        assert response.status_code == 200
        result = response.json()

@@ -122,8 +162,8 @@ def test_embedding(client):
# Run the test
# test_embedding()

@pytest.mark.skip(reason="hitting yaml load issues on circle-ci")
def test_add_new_model(client):
# @pytest.mark.skip(reason="hitting yaml load issues on circle-ci")
def test_add_new_model(client_no_auth):
    global headers
    try:
        test_data = {

@@ -135,15 +175,15 @@ def test_add_new_model(client):
                "description": "this is a test openai model"
            }
        }
        client.post("/model/new", json=test_data, headers=headers)
        response = client.get("/model/info", headers=headers)
        client_no_auth.post("/model/new", json=test_data, headers=headers)
        response = client_no_auth.get("/model/info", headers=headers)
        assert response.status_code == 200
        result = response.json()
        print(f"response: {result}")
        model_info = None
        for m in result["data"]:
            if m["id"]["model_name"] == "test_openai_models":
                model_info = m["id"]["model_info"]
            if m["model_name"] == "test_openai_models":
                model_info = m["model_info"]
        assert model_info["description"] == "this is a test openai model"
    except Exception as e:
        pytest.fail(f"LiteLLM Proxy test failed. Exception {str(e)}")

@@ -164,10 +204,9 @@ class MyCustomHandler(CustomLogger):
customHandler = MyCustomHandler()


def test_chat_completion_optional_params(client):
def test_chat_completion_optional_params(client_no_auth):
    # [PROXY: PROD TEST] - DO NOT DELETE
    # This tests if all the /chat/completion params are passed to litellm

    try:
        # Your test data
        litellm.set_verbose=True

@@ -185,7 +224,7 @@ def test_chat_completion_optional_params(client):

        litellm.callbacks = [customHandler]
        print("testing proxy server: optional params")
        response = client.post("/v1/chat/completions", json=test_data, headers=headers)
        response = client_no_auth.post("/v1/chat/completions", json=test_data)
        assert response.status_code == 200
        result = response.json()
        print(f"Received response: {result}")

@@ -217,6 +256,29 @@ def test_load_router_config():
        print(result)
        assert len(result[1]) == 2

        # tests for litellm.cache set from config
        print("testing reading proxy config for cache")
        litellm.cache = None
        load_router_config(
            router=None,
            config_file_path=f"{filepath}/example_config_yaml/cache_no_params.yaml"
        )
        assert litellm.cache is not None
        assert "redis_client" in vars(litellm.cache.cache) # it should default to redis on proxy
        assert litellm.cache.supported_call_types == ['completion', 'acompletion', 'embedding', 'aembedding'] # init with all call types

        print("testing reading proxy config for cache with params")
        load_router_config(
            router=None,
            config_file_path=f"{filepath}/example_config_yaml/cache_with_params.yaml"
        )
        assert litellm.cache is not None
        print(litellm.cache)
        print(litellm.cache.supported_call_types)
        print(vars(litellm.cache.cache))
        assert "redis_client" in vars(litellm.cache.cache) # it should default to redis on proxy
        assert litellm.cache.supported_call_types == ['embedding', 'aembedding'] # init with only the call types from the config

    except Exception as e:
        pytest.fail("Proxy: Got exception reading config", e)
# test_load_router_config()
# test_load_router_config()
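For reference, the cache assertions above boil down to the proxy constructing a litellm Cache from the config's cache_params; a minimal sketch of the equivalent Python, where the constructor arguments shown are assumptions based on the attributes the test inspects rather than the proxy's actual loader code:

import litellm
from litellm.caching import Cache

# hypothetical values standing in for what cache_with_params.yaml would provide
litellm.cache = Cache(
    type="redis",
    host="localhost",
    port="6379",
    password="my-redis-password",
    supported_call_types=["embedding", "aembedding"],  # restrict caching to embedding calls only
)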
@@ -37,6 +37,8 @@ async def wrapper_startup_event():
# Make sure the fixture returns TestClient(app)
@pytest.fixture(autouse=True)
def client():
    from litellm.proxy.proxy_server import cleanup_router_config_variables
    cleanup_router_config_variables()
    with TestClient(app) as client:
        yield client


@@ -69,6 +71,38 @@ def test_add_new_key(client):
    except Exception as e:
        pytest.fail(f"LiteLLM Proxy test failed. Exception: {str(e)}")


def test_update_new_key(client):
    try:
        # Your test data
        test_data = {
            "models": ["gpt-3.5-turbo", "gpt-4", "claude-2", "azure-model"],
            "aliases": {"mistral-7b": "gpt-3.5-turbo"},
            "duration": "20m"
        }
        print("testing proxy server")
        # Your bearer token
        token = os.getenv("PROXY_MASTER_KEY")

        headers = {
            "Authorization": f"Bearer {token}"
        }
        response = client.post("/key/generate", json=test_data, headers=headers)
        print(f"response: {response.text}")
        assert response.status_code == 200
        result = response.json()
        assert result["key"].startswith("sk-")
        def _post_data():
            json_data = {'models': ['bedrock-models'], "key": result["key"]}
            response = client.post("/key/update", json=json_data, headers=headers)
            print(f"response text: {response.text}")
            assert response.status_code == 200
            return response
        _post_data()
        print(f"Received response: {result}")
    except Exception as e:
        pytest.fail(f"LiteLLM Proxy test failed. Exception: {str(e)}")

# # Run the test - only runs via pytest
@@ -366,69 +366,12 @@ def test_function_calling():
        }
    ]

    router = Router(model_list=model_list, routing_strategy="latency-based-routing")
    router = Router(model_list=model_list)
    response = router.completion(model="gpt-3.5-turbo-0613", messages=messages, functions=functions)
    router.reset()
    print(response)

def test_acompletion_on_router():
    # tests acompletion + caching on router
    try:
        litellm.set_verbose = True
        model_list = [
            {
                "model_name": "gpt-3.5-turbo",
                "litellm_params": {
                    "model": "gpt-3.5-turbo-0613",
                    "api_key": os.getenv("OPENAI_API_KEY"),
                },
                "tpm": 100000,
                "rpm": 10000,
            },
            {
                "model_name": "gpt-3.5-turbo",
                "litellm_params": {
                    "model": "azure/chatgpt-v-2",
                    "api_key": os.getenv("AZURE_API_KEY"),
                    "api_base": os.getenv("AZURE_API_BASE"),
                    "api_version": os.getenv("AZURE_API_VERSION")
                },
                "tpm": 100000,
                "rpm": 10000,
            }
        ]

        messages = [
            {"role": "user", "content": f"write a one sentence poem {time.time()}?"}
        ]
        start_time = time.time()
        router = Router(model_list=model_list,
                        redis_host=os.environ["REDIS_HOST"],
                        redis_password=os.environ["REDIS_PASSWORD"],
                        redis_port=os.environ["REDIS_PORT"],
                        cache_responses=True,
                        timeout=30,
                        routing_strategy="simple-shuffle")
        async def get_response():
            print("Testing acompletion + caching on router")
            response1 = await router.acompletion(model="gpt-3.5-turbo", messages=messages, temperature=1)
            print(f"response1: {response1}")
            response2 = await router.acompletion(model="gpt-3.5-turbo", messages=messages, temperature=1)
            print(f"response2: {response2}")
            assert response1.id == response2.id
            assert len(response1.choices[0].message.content) > 0
            assert response1.choices[0].message.content == response2.choices[0].message.content
        asyncio.run(get_response())
        router.reset()
    except litellm.Timeout as e:
        end_time = time.time()
        print(f"timeout error occurred: {end_time - start_time}")
        pass
    except Exception as e:
        traceback.print_exc()
        pytest.fail(f"Error occurred: {e}")

test_acompletion_on_router()
# test_acompletion_on_router()

def test_function_calling_on_router():
    try:

@@ -507,7 +450,6 @@ def test_aembedding_on_router():
            model="text-embedding-ada-002",
            input=["good morning from litellm 2"],
        )
        print("sync embedding response: ", response)
        router.reset()
    except Exception as e:
        traceback.print_exc()

@@ -591,6 +533,30 @@ def test_bedrock_on_router():
        pytest.fail(f"Error occurred: {e}")
# test_bedrock_on_router()

# test openai-compatible endpoint
@pytest.mark.asyncio
async def test_mistral_on_router():
    litellm.set_verbose = True
    model_list = [
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {
                "model": "mistral/mistral-medium",
            },
        },
    ]
    router = Router(model_list=model_list)
    response = await router.acompletion(
        model="gpt-3.5-turbo",
        messages=[
            {
                "role": "user",
                "content": "hello from litellm test",
            }
        ]
    )
    print(response)
asyncio.run(test_mistral_on_router())

def test_openai_completion_on_router():
    # [PROD Use Case] - Makes an acompletion call + async acompletion call, and sync acompletion call, sync completion + stream
127
litellm/tests/test_router_caching.py
Normal file
@@ -0,0 +1,127 @@
#### What this tests ####
# This tests caching on the router
import sys, os, time
import traceback, asyncio
import pytest
sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
import litellm
from litellm import Router

## Scenarios
## 1. 2 models - openai + azure - 1 model group "gpt-3.5-turbo",
## 2. 2 models - openai, azure - 2 diff model groups, 1 caching group

@pytest.mark.asyncio
async def test_acompletion_caching_on_router():
    # tests acompletion + caching on router
    try:
        litellm.set_verbose = True
        model_list = [
            {
                "model_name": "gpt-3.5-turbo",
                "litellm_params": {
                    "model": "gpt-3.5-turbo-0613",
                    "api_key": os.getenv("OPENAI_API_KEY"),
                },
                "tpm": 100000,
                "rpm": 10000,
            },
            {
                "model_name": "gpt-3.5-turbo",
                "litellm_params": {
                    "model": "azure/chatgpt-v-2",
                    "api_key": os.getenv("AZURE_API_KEY"),
                    "api_base": os.getenv("AZURE_API_BASE"),
                    "api_version": os.getenv("AZURE_API_VERSION")
                },
                "tpm": 100000,
                "rpm": 10000,
            }
        ]

        messages = [
            {"role": "user", "content": f"write a one sentence poem {time.time()}?"}
        ]
        start_time = time.time()
        router = Router(model_list=model_list,
                        redis_host=os.environ["REDIS_HOST"],
                        redis_password=os.environ["REDIS_PASSWORD"],
                        redis_port=os.environ["REDIS_PORT"],
                        cache_responses=True,
                        timeout=30,
                        routing_strategy="simple-shuffle")
        response1 = await router.acompletion(model="gpt-3.5-turbo", messages=messages, temperature=1)
        print(f"response1: {response1}")
        await asyncio.sleep(1) # add cache is async, async sleep for cache to get set
        response2 = await router.acompletion(model="gpt-3.5-turbo", messages=messages, temperature=1)
        print(f"response2: {response2}")
        assert response1.id == response2.id
        assert len(response1.choices[0].message.content) > 0
        assert response1.choices[0].message.content == response2.choices[0].message.content
        router.reset()
    except litellm.Timeout as e:
        end_time = time.time()
        print(f"timeout error occurred: {end_time - start_time}")
        pass
    except Exception as e:
        traceback.print_exc()
        pytest.fail(f"Error occurred: {e}")

@pytest.mark.asyncio
async def test_acompletion_caching_on_router_caching_groups():
    # tests acompletion + caching on router
    try:
        litellm.set_verbose = True
        model_list = [
            {
                "model_name": "openai-gpt-3.5-turbo",
                "litellm_params": {
                    "model": "gpt-3.5-turbo-0613",
                    "api_key": os.getenv("OPENAI_API_KEY"),
                },
                "tpm": 100000,
                "rpm": 10000,
            },
            {
                "model_name": "azure-gpt-3.5-turbo",
                "litellm_params": {
                    "model": "azure/chatgpt-v-2",
                    "api_key": os.getenv("AZURE_API_KEY"),
                    "api_base": os.getenv("AZURE_API_BASE"),
                    "api_version": os.getenv("AZURE_API_VERSION")
                },
                "tpm": 100000,
                "rpm": 10000,
            }
        ]

        messages = [
            {"role": "user", "content": f"write a one sentence poem {time.time()}?"}
        ]
        start_time = time.time()
        router = Router(model_list=model_list,
                        redis_host=os.environ["REDIS_HOST"],
                        redis_password=os.environ["REDIS_PASSWORD"],
                        redis_port=os.environ["REDIS_PORT"],
                        cache_responses=True,
                        timeout=30,
                        routing_strategy="simple-shuffle",
                        caching_groups=[("openai-gpt-3.5-turbo", "azure-gpt-3.5-turbo")])
        response1 = await router.acompletion(model="openai-gpt-3.5-turbo", messages=messages, temperature=1)
        print(f"response1: {response1}")
        await asyncio.sleep(1) # add cache is async, async sleep for cache to get set
        response2 = await router.acompletion(model="azure-gpt-3.5-turbo", messages=messages, temperature=1)
        print(f"response2: {response2}")
        assert response1.id == response2.id
        assert len(response1.choices[0].message.content) > 0
        assert response1.choices[0].message.content == response2.choices[0].message.content
        router.reset()
    except litellm.Timeout as e:
        end_time = time.time()
        print(f"timeout error occurred: {end_time - start_time}")
        pass
    except Exception as e:
        traceback.print_exc()
        pytest.fail(f"Error occurred: {e}")
Some files were not shown because too many files have changed in this diff.