diff --git a/.circleci/config.yml b/.circleci/config.yml index 26a2ae356b..b43a8aa64c 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -125,6 +125,7 @@ jobs: pip install tiktoken pip install aiohttp pip install click + pip install "boto3==1.34.34" pip install jinja2 pip install tokenizers pip install openai @@ -287,6 +288,7 @@ jobs: pip install "pytest==7.3.1" pip install "pytest-mock==3.12.0" pip install "pytest-asyncio==0.21.1" + pip install "boto3==1.34.34" pip install mypy pip install pyarrow pip install numpydoc diff --git a/Dockerfile b/Dockerfile index c8e9956b29..bd840eaf54 100644 --- a/Dockerfile +++ b/Dockerfile @@ -62,6 +62,11 @@ COPY --from=builder /wheels/ /wheels/ RUN pip install *.whl /wheels/* --no-index --find-links=/wheels/ && rm -f *.whl && rm -rf /wheels # Generate prisma client +ENV PRISMA_BINARY_CACHE_DIR=/app/prisma +RUN mkdir -p /.cache +RUN chmod -R 777 /.cache +RUN pip install nodejs-bin +RUN pip install prisma RUN prisma generate RUN chmod +x entrypoint.sh diff --git a/Dockerfile.database b/Dockerfile.database index 22084bab89..c995939e5b 100644 --- a/Dockerfile.database +++ b/Dockerfile.database @@ -62,6 +62,11 @@ RUN pip install PyJWT --no-cache-dir RUN chmod +x build_admin_ui.sh && ./build_admin_ui.sh # Generate prisma client +ENV PRISMA_BINARY_CACHE_DIR=/app/prisma +RUN mkdir -p /.cache +RUN chmod -R 777 /.cache +RUN pip install nodejs-bin +RUN pip install prisma RUN prisma generate RUN chmod +x entrypoint.sh diff --git a/docs/my-website/docs/completion/json_mode.md b/docs/my-website/docs/completion/json_mode.md index bf159cd07e..1d12a22ba0 100644 --- a/docs/my-website/docs/completion/json_mode.md +++ b/docs/my-website/docs/completion/json_mode.md @@ -84,17 +84,20 @@ from litellm import completion # add to env var os.environ["OPENAI_API_KEY"] = "" -messages = [{"role": "user", "content": "List 5 cookie recipes"}] +messages = [{"role": "user", "content": "List 5 important events in the XIX century"}] class CalendarEvent(BaseModel): name: str date: str participants: list[str] +class EventsList(BaseModel): + events: list[CalendarEvent] + resp = completion( model="gpt-4o-2024-08-06", messages=messages, - response_format=CalendarEvent + response_format=EventsList ) print("Received={}".format(resp)) diff --git a/docs/my-website/docs/providers/anthropic.md b/docs/my-website/docs/providers/anthropic.md index 2227b7a6b5..2a7804bfda 100644 --- a/docs/my-website/docs/providers/anthropic.md +++ b/docs/my-website/docs/providers/anthropic.md @@ -225,22 +225,336 @@ print(response) | claude-instant-1.2 | `completion('claude-instant-1.2', messages)` | `os.environ['ANTHROPIC_API_KEY']` | | claude-instant-1 | `completion('claude-instant-1', messages)` | `os.environ['ANTHROPIC_API_KEY']` | -## Passing Extra Headers to Anthropic API +## **Prompt Caching** -Pass `extra_headers: dict` to `litellm.completion` +Use Anthropic Prompt Caching -```python -from litellm import completion -messages = [{"role": "user", "content": "What is Anthropic?"}] -response = completion( - model="claude-3-5-sonnet-20240620", - messages=messages, - extra_headers={"anthropic-beta": "max-tokens-3-5-sonnet-2024-07-15"} + +[Relevant Anthropic API Docs](https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching) + +### Caching - Large Context Caching + +This example demonstrates basic Prompt Caching usage, caching the full text of the legal agreement as a prefix while keeping the user instruction uncached. 
+ + + + +```python +response = await litellm.acompletion( + model="anthropic/claude-3-5-sonnet-20240620", + messages=[ + { + "role": "system", + "content": [ + { + "type": "text", + "text": "You are an AI assistant tasked with analyzing legal documents.", + }, + { + "type": "text", + "text": "Here is the full text of a complex legal agreement", + "cache_control": {"type": "ephemeral"}, + }, + ], + }, + { + "role": "user", + "content": "what are the key terms and conditions in this agreement?", + }, + ], + extra_headers={ + "anthropic-version": "2023-06-01", + "anthropic-beta": "prompt-caching-2024-07-31", + }, +) + +``` + + + +:::info + +LiteLLM Proxy is OpenAI compatible + +This is an example using the OpenAI Python SDK sending a request to LiteLLM Proxy + +Assuming you have a model=`anthropic/claude-3-5-sonnet-20240620` on the [litellm proxy config.yaml](#usage-with-litellm-proxy) + +::: + +```python +import openai +client = openai.AsyncOpenAI( + api_key="anything", # litellm proxy api key + base_url="http://0.0.0.0:4000" # litellm proxy base url +) + + +response = await client.chat.completions.create( + model="anthropic/claude-3-5-sonnet-20240620", + messages=[ + { + "role": "system", + "content": [ + { + "type": "text", + "text": "You are an AI assistant tasked with analyzing legal documents.", + }, + { + "type": "text", + "text": "Here is the full text of a complex legal agreement", + "cache_control": {"type": "ephemeral"}, + }, + ], + }, + { + "role": "user", + "content": "what are the key terms and conditions in this agreement?", + }, + ], + extra_headers={ + "anthropic-version": "2023-06-01", + "anthropic-beta": "prompt-caching-2024-07-31", + }, +) + +``` + + + + +### Caching - Tools definitions + +In this example, we demonstrate caching tool definitions. + +The cache_control parameter is placed on the final tool + + + + +```python +import litellm + +response = await litellm.acompletion( + model="anthropic/claude-3-5-sonnet-20240620", + messages = [{"role": "user", "content": "What's the weather like in Boston today?"}] + tools = [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. 
San Francisco, CA", + }, + "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, + }, + "required": ["location"], + }, + "cache_control": {"type": "ephemeral"} + }, + } + ], + extra_headers={ + "anthropic-version": "2023-06-01", + "anthropic-beta": "prompt-caching-2024-07-31", + }, ) ``` -## Advanced + + -## Usage - Function Calling +:::info + +LiteLLM Proxy is OpenAI compatible + +This is an example using the OpenAI Python SDK sending a request to LiteLLM Proxy + +Assuming you have a model=`anthropic/claude-3-5-sonnet-20240620` on the [litellm proxy config.yaml](#usage-with-litellm-proxy) + +::: + +```python +import openai +client = openai.AsyncOpenAI( + api_key="anything", # litellm proxy api key + base_url="http://0.0.0.0:4000" # litellm proxy base url +) + +response = await client.chat.completions.create( + model="anthropic/claude-3-5-sonnet-20240620", + messages = [{"role": "user", "content": "What's the weather like in Boston today?"}] + tools = [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + }, + "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, + }, + "required": ["location"], + }, + "cache_control": {"type": "ephemeral"} + }, + } + ], + extra_headers={ + "anthropic-version": "2023-06-01", + "anthropic-beta": "prompt-caching-2024-07-31", + }, +) +``` + + + + + +### Caching - Continuing Multi-Turn Convo + +In this example, we demonstrate how to use Prompt Caching in a multi-turn conversation. + +The cache_control parameter is placed on the system message to designate it as part of the static prefix. + +The conversation history (previous messages) is included in the messages array. The final turn is marked with cache-control, for continuing in followups. The second-to-last user message is marked for caching with the cache_control parameter, so that this checkpoint can read from the previous cache. + + + + +```python +import litellm + +response = await litellm.acompletion( + model="anthropic/claude-3-5-sonnet-20240620", + messages=[ + # System Message + { + "role": "system", + "content": [ + { + "type": "text", + "text": "Here is the full text of a complex legal agreement" + * 400, + "cache_control": {"type": "ephemeral"}, + } + ], + }, + # marked for caching with the cache_control parameter, so that this checkpoint can read from the previous cache. + { + "role": "user", + "content": [ + { + "type": "text", + "text": "What are the key terms and conditions in this agreement?", + "cache_control": {"type": "ephemeral"}, + } + ], + }, + { + "role": "assistant", + "content": "Certainly! the key terms and conditions are the following: the contract is 1 year long for $10/mo", + }, + # The final turn is marked with cache-control, for continuing in followups. 
+ { + "role": "user", + "content": [ + { + "type": "text", + "text": "What are the key terms and conditions in this agreement?", + "cache_control": {"type": "ephemeral"}, + } + ], + }, + ], + extra_headers={ + "anthropic-version": "2023-06-01", + "anthropic-beta": "prompt-caching-2024-07-31", + }, +) +``` + + + +:::info + +LiteLLM Proxy is OpenAI compatible + +This is an example using the OpenAI Python SDK sending a request to LiteLLM Proxy + +Assuming you have a model=`anthropic/claude-3-5-sonnet-20240620` on the [litellm proxy config.yaml](#usage-with-litellm-proxy) + +::: + +```python +import openai +client = openai.AsyncOpenAI( + api_key="anything", # litellm proxy api key + base_url="http://0.0.0.0:4000" # litellm proxy base url +) + +response = await client.chat.completions.create( + model="anthropic/claude-3-5-sonnet-20240620", + messages=[ + # System Message + { + "role": "system", + "content": [ + { + "type": "text", + "text": "Here is the full text of a complex legal agreement" + * 400, + "cache_control": {"type": "ephemeral"}, + } + ], + }, + # marked for caching with the cache_control parameter, so that this checkpoint can read from the previous cache. + { + "role": "user", + "content": [ + { + "type": "text", + "text": "What are the key terms and conditions in this agreement?", + "cache_control": {"type": "ephemeral"}, + } + ], + }, + { + "role": "assistant", + "content": "Certainly! the key terms and conditions are the following: the contract is 1 year long for $10/mo", + }, + # The final turn is marked with cache-control, for continuing in followups. + { + "role": "user", + "content": [ + { + "type": "text", + "text": "What are the key terms and conditions in this agreement?", + "cache_control": {"type": "ephemeral"}, + } + ], + }, + ], + extra_headers={ + "anthropic-version": "2023-06-01", + "anthropic-beta": "prompt-caching-2024-07-31", + }, +) +``` + + + + +## **Function/Tool Calling** :::info @@ -429,6 +743,20 @@ resp = litellm.completion( print(f"\nResponse: {resp}") ``` +## **Passing Extra Headers to Anthropic API** + +Pass `extra_headers: dict` to `litellm.completion` + +```python +from litellm import completion +messages = [{"role": "user", "content": "What is Anthropic?"}] +response = completion( + model="claude-3-5-sonnet-20240620", + messages=messages, + extra_headers={"anthropic-beta": "max-tokens-3-5-sonnet-2024-07-15"} +) +``` + ## Usage - "Assistant Pre-fill" You can "put words in Claude's mouth" by including an `assistant` role message as the last item in the `messages` array. diff --git a/docs/my-website/docs/providers/bedrock.md b/docs/my-website/docs/providers/bedrock.md index 485dbf892b..907dfc2337 100644 --- a/docs/my-website/docs/providers/bedrock.md +++ b/docs/my-website/docs/providers/bedrock.md @@ -393,7 +393,7 @@ response = completion( ) ``` - + ```python @@ -420,6 +420,55 @@ extra_body={ } ) +print(response) +``` + + + +1. Update config.yaml + +```yaml +model_list: + - model_name: bedrock-claude-v1 + litellm_params: + model: bedrock/anthropic.claude-instant-v1 + aws_access_key_id: os.environ/CUSTOM_AWS_ACCESS_KEY_ID + aws_secret_access_key: os.environ/CUSTOM_AWS_SECRET_ACCESS_KEY + aws_region_name: os.environ/CUSTOM_AWS_REGION_NAME + guardrailConfig: { + "guardrailIdentifier": "ff6ujrregl1q", # The identifier (ID) for the guardrail. + "guardrailVersion": "DRAFT", # The version of the guardrail. + "trace": "disabled", # The trace behavior for the guardrail. Can either be "disabled" or "enabled" + } + +``` + +2. 
Start proxy + +```bash +litellm --config /path/to/config.yaml +``` + +3. Test it! + +```python + +import openai +client = openai.OpenAI( + api_key="anything", + base_url="http://0.0.0.0:4000" +) + +# request sent to model set on litellm proxy, `litellm --model` +response = client.chat.completions.create(model="bedrock-claude-v1", messages = [ + { + "role": "user", + "content": "this is a test request, write a short poem" + } +], +temperature=0.7 +) + print(response) ``` diff --git a/docs/my-website/docs/proxy/deploy.md b/docs/my-website/docs/proxy/deploy.md index 7c254ed35d..9f21068e03 100644 --- a/docs/my-website/docs/proxy/deploy.md +++ b/docs/my-website/docs/proxy/deploy.md @@ -705,6 +705,29 @@ docker run ghcr.io/berriai/litellm:main-latest \ Provide an ssl certificate when starting litellm proxy server +### 3. Providing LiteLLM config.yaml file as a s3 Object/url + +Use this if you cannot mount a config file on your deployment service (example - AWS Fargate, Railway etc) + +LiteLLM Proxy will read your config.yaml from an s3 Bucket + +Set the following .env vars +```shell +LITELLM_CONFIG_BUCKET_NAME = "litellm-proxy" # your bucket name on s3 +LITELLM_CONFIG_BUCKET_OBJECT_KEY = "litellm_proxy_config.yaml" # object key on s3 +``` + +Start litellm proxy with these env vars - litellm will read your config from s3 + +```shell +docker run --name litellm-proxy \ + -e DATABASE_URL= \ + -e LITELLM_CONFIG_BUCKET_NAME= \ + -e LITELLM_CONFIG_BUCKET_OBJECT_KEY="> \ + -p 4000:4000 \ + ghcr.io/berriai/litellm-database:main-latest +``` + ## Platform-specific Guide diff --git a/docs/my-website/docs/proxy/model_management.md b/docs/my-website/docs/proxy/model_management.md index 02ce4ba23b..a8cc66ae76 100644 --- a/docs/my-website/docs/proxy/model_management.md +++ b/docs/my-website/docs/proxy/model_management.md @@ -17,7 +17,7 @@ model_list: ## Get Model Information - `/model/info` -Retrieve detailed information about each model listed in the `/model/info` endpoint, including descriptions from the `config.yaml` file, and additional model info (e.g. max tokens, cost per input token, etc.) pulled the model_info you set and the litellm model cost map. Sensitive details like API keys are excluded for security purposes. +Retrieve detailed information about each model listed in the `/model/info` endpoint, including descriptions from the `config.yaml` file, and additional model info (e.g. max tokens, cost per input token, etc.) pulled from the model_info you set and the [litellm model cost map](https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json). Sensitive details like API keys are excluded for security purposes. 
- + + ```bash curl -X POST "http://0.0.0.0:4000/model/new" \ - -H "accept: application/json" \ - -H "Content-Type: application/json" \ - -d '{ "model_name": "azure-gpt-turbo", "litellm_params": {"model": "azure/gpt-3.5-turbo", "api_key": "os.environ/AZURE_API_KEY", "api_base": "my-azure-api-base"} }' + -H "accept: application/json" \ + -H "Content-Type: application/json" \ + -d '{ "model_name": "azure-gpt-turbo", "litellm_params": {"model": "azure/gpt-3.5-turbo", "api_key": "os.environ/AZURE_API_KEY", "api_base": "my-azure-api-base"} }' ``` - + + + +```yaml +model_list: + - model_name: gpt-3.5-turbo ### RECEIVED MODEL NAME ### `openai.chat.completions.create(model="gpt-3.5-turbo",...)` + litellm_params: # all params accepted by litellm.completion() - https://github.com/BerriAI/litellm/blob/9b46ec05b02d36d6e4fb5c32321e51e7f56e4a6e/litellm/types/router.py#L297 + model: azure/gpt-turbo-small-eu ### MODEL NAME sent to `litellm.completion()` ### + api_base: https://my-endpoint-europe-berri-992.openai.azure.com/ + api_key: "os.environ/AZURE_API_KEY_EU" # does os.getenv("AZURE_API_KEY_EU") + rpm: 6 # [OPTIONAL] Rate limit for this deployment: in requests per minute (rpm) + model_info: + my_custom_key: my_custom_value # additional model metadata +``` + + @@ -85,4 +96,83 @@ Keep in mind that as both endpoints are in [BETA], you may need to visit the ass - Get Model Information: [Issue #933](https://github.com/BerriAI/litellm/issues/933) - Add a New Model: [Issue #964](https://github.com/BerriAI/litellm/issues/964) -Feedback on the beta endpoints is valuable and helps improve the API for all users. \ No newline at end of file +Feedback on the beta endpoints is valuable and helps improve the API for all users. + + +## Add Additional Model Information + +If you want the ability to add a display name, description, and labels for models, just use `model_info:` + +```yaml +model_list: + - model_name: "gpt-4" + litellm_params: + model: "gpt-4" + api_key: "os.environ/OPENAI_API_KEY" + model_info: # 👈 KEY CHANGE + my_custom_key: "my_custom_value" +``` + +### Usage + +1. Add additional information to model + +```yaml +model_list: + - model_name: "gpt-4" + litellm_params: + model: "gpt-4" + api_key: "os.environ/OPENAI_API_KEY" + model_info: # 👈 KEY CHANGE + my_custom_key: "my_custom_value" +``` + +2. Call with `/model/info` + +Use a key with access to the model `gpt-4`. + +```bash +curl -L -X GET 'http://0.0.0.0:4000/v1/model/info' \ +-H 'Authorization: Bearer LITELLM_KEY' \ +``` + +3. 
**Expected Response** + +Returned `model_info = Your custom model_info + (if exists) LITELLM MODEL INFO` + + +[**How LiteLLM Model Info is found**](https://github.com/BerriAI/litellm/blob/9b46ec05b02d36d6e4fb5c32321e51e7f56e4a6e/litellm/proxy/proxy_server.py#L7460) + +[Tell us how this can be improved!](https://github.com/BerriAI/litellm/issues) + +```bash +{ + "data": [ + { + "model_name": "gpt-4", + "litellm_params": { + "model": "gpt-4" + }, + "model_info": { + "id": "e889baacd17f591cce4c63639275ba5e8dc60765d6c553e6ee5a504b19e50ddc", + "db_model": false, + "my_custom_key": "my_custom_value", # 👈 CUSTOM INFO + "key": "gpt-4", # 👈 KEY in LiteLLM MODEL INFO/COST MAP - https://github.com/BerriAI/litellm/blob/main/model_prices_and_context_window.json + "max_tokens": 4096, + "max_input_tokens": 8192, + "max_output_tokens": 4096, + "input_cost_per_token": 3e-05, + "input_cost_per_character": null, + "input_cost_per_token_above_128k_tokens": null, + "output_cost_per_token": 6e-05, + "output_cost_per_character": null, + "output_cost_per_token_above_128k_tokens": null, + "output_cost_per_character_above_128k_tokens": null, + "output_vector_size": null, + "litellm_provider": "openai", + "mode": "chat" + } + }, + ] +} +``` diff --git a/docs/my-website/docs/proxy/pass_through.md b/docs/my-website/docs/proxy/pass_through.md index 4554f80135..bad23f0de0 100644 --- a/docs/my-website/docs/proxy/pass_through.md +++ b/docs/my-website/docs/proxy/pass_through.md @@ -193,6 +193,53 @@ curl --request POST \ }' ``` +### Use Langfuse client sdk w/ LiteLLM Key + +**Usage** + +1. Set-up yaml to pass-through langfuse /api/public/ingestion + +```yaml +general_settings: + master_key: sk-1234 + pass_through_endpoints: + - path: "/api/public/ingestion" # route you want to add to LiteLLM Proxy Server + target: "https://us.cloud.langfuse.com/api/public/ingestion" # URL this route should forward + auth: true # 👈 KEY CHANGE + custom_auth_parser: "langfuse" # 👈 KEY CHANGE + headers: + LANGFUSE_PUBLIC_KEY: "os.environ/LANGFUSE_DEV_PUBLIC_KEY" # your langfuse account public key + LANGFUSE_SECRET_KEY: "os.environ/LANGFUSE_DEV_SK_KEY" # your langfuse account secret key +``` + +2. Start proxy + +```bash +litellm --config /path/to/config.yaml +``` + +3. Test with langfuse sdk + + +```python + +from langfuse import Langfuse + +langfuse = Langfuse( + host="http://localhost:4000", # your litellm proxy endpoint + public_key="sk-1234", # your litellm proxy api key + secret_key="anything", # no key required since this is a pass through +) + +print("sending langfuse trace request") +trace = langfuse.trace(name="test-trace-litellm-proxy-passthrough") +print("flushing langfuse request") +langfuse.flush() + +print("flushed langfuse request") +``` + + ## `pass_through_endpoints` Spec on config.yaml All possible values for `pass_through_endpoints` and what they mean diff --git a/docs/my-website/docs/proxy/prometheus.md b/docs/my-website/docs/proxy/prometheus.md index 6c856f58b3..4b913d2e82 100644 --- a/docs/my-website/docs/proxy/prometheus.md +++ b/docs/my-website/docs/proxy/prometheus.md @@ -72,15 +72,15 @@ http://localhost:4000/metrics | Metric Name | Description | |----------------------|--------------------------------------| -| `deployment_state` | The state of the deployment: 0 = healthy, 1 = partial outage, 2 = complete outage. | +| `litellm_deployment_state` | The state of the deployment: 0 = healthy, 1 = partial outage, 2 = complete outage. 
| | `litellm_remaining_requests_metric` | Track `x-ratelimit-remaining-requests` returned from LLM API Deployment | | `litellm_remaining_tokens` | Track `x-ratelimit-remaining-tokens` return from LLM API Deployment | - `llm_deployment_success_responses` | Total number of successful LLM API calls for deployment | -| `llm_deployment_failure_responses` | Total number of failed LLM API calls for deployment | -| `llm_deployment_total_requests` | Total number of LLM API calls for deployment - success + failure | -| `llm_deployment_latency_per_output_token` | Latency per output token for deployment | -| `llm_deployment_successful_fallbacks` | Number of successful fallback requests from primary model -> fallback model | -| `llm_deployment_failed_fallbacks` | Number of failed fallback requests from primary model -> fallback model | + `litellm_deployment_success_responses` | Total number of successful LLM API calls for deployment | +| `litellm_deployment_failure_responses` | Total number of failed LLM API calls for deployment | +| `litellm_deployment_total_requests` | Total number of LLM API calls for deployment - success + failure | +| `litellm_deployment_latency_per_output_token` | Latency per output token for deployment | +| `litellm_deployment_successful_fallbacks` | Number of successful fallback requests from primary model -> fallback model | +| `litellm_deployment_failed_fallbacks` | Number of failed fallback requests from primary model -> fallback model | diff --git a/docs/my-website/docs/proxy/team_logging.md b/docs/my-website/docs/proxy/team_logging.md index 1cc91c2dfe..e36cb8f669 100644 --- a/docs/my-website/docs/proxy/team_logging.md +++ b/docs/my-website/docs/proxy/team_logging.md @@ -2,9 +2,9 @@ import Image from '@theme/IdealImage'; import Tabs from '@theme/Tabs'; import TabItem from '@theme/TabItem'; -# 👥📊 Team Based Logging +# 👥📊 Team/Key Based Logging -Allow each team to use their own Langfuse Project / custom callbacks +Allow each key/team to use their own Langfuse Project / custom callbacks **This allows you to do the following** ``` @@ -189,3 +189,39 @@ curl -X GET 'http://localhost:4000/team/dbe2f686-a686-4896-864a-4c3924458709/cal + + +## [BETA] Key Based Logging + +Use the `/key/generate` or `/key/update` endpoints to add logging callbacks to a specific key. 
+ +:::info + +✨ This is an Enterprise only feature [Get Started with Enterprise here](https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat) + +::: + +```bash +curl -X POST 'http://0.0.0.0:4000/key/generate' \ +-H 'Authorization: Bearer sk-1234' \ +-H 'Content-Type: application/json' \ +-d '{ + "metadata": { + "logging": { + "callback_name": "langfuse", # 'otel', 'langfuse', 'lunary' + "callback_type": "success" # set, if required by integration - future improvement, have logging tools work for success + failure by default + "callback_vars": { + "langfuse_public_key": "os.environ/LANGFUSE_PUBLIC_KEY", # [RECOMMENDED] reference key in proxy environment + "langfuse_secret_key": "os.environ/LANGFUSE_SECRET_KEY", # [RECOMMENDED] reference key in proxy environment + "langfuse_host": "https://cloud.langfuse.com" + } + } + } +}' + +``` + +--- + +Help us improve this feature, by filing a [ticket here](https://github.com/BerriAI/litellm/issues) + diff --git a/docs/my-website/sidebars.js b/docs/my-website/sidebars.js index 7df5e61578..3c3e1cbf97 100644 --- a/docs/my-website/sidebars.js +++ b/docs/my-website/sidebars.js @@ -151,7 +151,7 @@ const sidebars = { }, { type: "category", - label: "Chat Completions (litellm.completion)", + label: "Chat Completions (litellm.completion + PROXY)", link: { type: "generated-index", title: "Chat Completions", diff --git a/litellm/integrations/gcs_bucket.py b/litellm/integrations/gcs_bucket.py index 46f55f8f01..be7f8e39c2 100644 --- a/litellm/integrations/gcs_bucket.py +++ b/litellm/integrations/gcs_bucket.py @@ -1,5 +1,6 @@ import json import os +import uuid from datetime import datetime from typing import Any, Dict, List, Optional, TypedDict, Union @@ -29,6 +30,8 @@ class GCSBucketPayload(TypedDict): end_time: str response_cost: Optional[float] spend_log_metadata: str + exception: Optional[str] + log_event_type: Optional[str] class GCSBucketLogger(CustomLogger): @@ -79,6 +82,7 @@ class GCSBucketLogger(CustomLogger): logging_payload: GCSBucketPayload = await self.get_gcs_payload( kwargs, response_obj, start_time_str, end_time_str ) + logging_payload["log_event_type"] = "successful_api_call" json_logged_payload = json.dumps(logging_payload) @@ -103,7 +107,56 @@ class GCSBucketLogger(CustomLogger): verbose_logger.error("GCS Bucket logging error: %s", str(e)) async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time): - pass + from litellm.proxy.proxy_server import premium_user + + if premium_user is not True: + raise ValueError( + f"GCS Bucket logging is a premium feature. Please upgrade to use it. 
{CommonProxyErrors.not_premium_user.value}" + ) + try: + verbose_logger.debug( + "GCS Logger: async_log_failure_event logging kwargs: %s, response_obj: %s", + kwargs, + response_obj, + ) + + start_time_str = start_time.strftime("%Y-%m-%d %H:%M:%S") + end_time_str = end_time.strftime("%Y-%m-%d %H:%M:%S") + headers = await self.construct_request_headers() + + logging_payload: GCSBucketPayload = await self.get_gcs_payload( + kwargs, response_obj, start_time_str, end_time_str + ) + logging_payload["log_event_type"] = "failed_api_call" + + _litellm_params = kwargs.get("litellm_params") or {} + metadata = _litellm_params.get("metadata") or {} + + json_logged_payload = json.dumps(logging_payload) + + # Get the current date + current_date = datetime.now().strftime("%Y-%m-%d") + + # Modify the object_name to include the date-based folder + object_name = f"{current_date}/failure-{uuid.uuid4().hex}" + + if "gcs_log_id" in metadata: + object_name = metadata["gcs_log_id"] + + response = await self.async_httpx_client.post( + headers=headers, + url=f"https://storage.googleapis.com/upload/storage/v1/b/{self.BUCKET_NAME}/o?uploadType=media&name={object_name}", + data=json_logged_payload, + ) + + if response.status_code != 200: + verbose_logger.error("GCS Bucket logging error: %s", str(response.text)) + + verbose_logger.debug("GCS Bucket response %s", response) + verbose_logger.debug("GCS Bucket status code %s", response.status_code) + verbose_logger.debug("GCS Bucket response.text %s", response.text) + except Exception as e: + verbose_logger.error("GCS Bucket logging error: %s", str(e)) async def construct_request_headers(self) -> Dict[str, str]: from litellm import vertex_chat_completion @@ -139,9 +192,18 @@ class GCSBucketLogger(CustomLogger): optional_params=kwargs.get("optional_params", None), ) response_dict = {} - response_dict = convert_litellm_response_object_to_dict( - response_obj=response_obj - ) + if response_obj: + response_dict = convert_litellm_response_object_to_dict( + response_obj=response_obj + ) + + exception_str = None + + # Handle logging exception attributes + if "exception" in kwargs: + exception_str = kwargs.get("exception", "") + if not isinstance(exception_str, str): + exception_str = str(exception_str) _spend_log_payload: SpendLogsPayload = get_logging_payload( kwargs=kwargs, @@ -156,8 +218,10 @@ class GCSBucketLogger(CustomLogger): response_obj=response_dict, start_time=start_time, end_time=end_time, - spend_log_metadata=_spend_log_payload["metadata"], + spend_log_metadata=_spend_log_payload.get("metadata", ""), response_cost=kwargs.get("response_cost", None), + exception=exception_str, + log_event_type=None, ) return gcs_payload diff --git a/litellm/integrations/langfuse.py b/litellm/integrations/langfuse.py index 864fb34e20..d6c235d0cb 100644 --- a/litellm/integrations/langfuse.py +++ b/litellm/integrations/langfuse.py @@ -605,6 +605,12 @@ class LangFuseLogger: if "cache_key" in litellm.langfuse_default_tags: _hidden_params = metadata.get("hidden_params", {}) or {} _cache_key = _hidden_params.get("cache_key", None) + if _cache_key is None: + # fallback to using "preset_cache_key" + _preset_cache_key = kwargs.get("litellm_params", {}).get( + "preset_cache_key", None + ) + _cache_key = _preset_cache_key tags.append(f"cache_key:{_cache_key}") return tags @@ -676,7 +682,6 @@ def log_provider_specific_information_as_span( Returns: None """ - from litellm.proxy.proxy_server import premium_user _hidden_params = clean_metadata.get("hidden_params", None) if _hidden_params is None: 
diff --git a/litellm/integrations/prometheus.py b/litellm/integrations/prometheus.py index 8797807ac6..08431fd7af 100644 --- a/litellm/integrations/prometheus.py +++ b/litellm/integrations/prometheus.py @@ -141,42 +141,42 @@ class PrometheusLogger(CustomLogger): ] # Metric for deployment state - self.deployment_state = Gauge( - "deployment_state", + self.litellm_deployment_state = Gauge( + "litellm_deployment_state", "LLM Deployment Analytics - The state of the deployment: 0 = healthy, 1 = partial outage, 2 = complete outage", labelnames=_logged_llm_labels, ) - self.llm_deployment_success_responses = Counter( - name="llm_deployment_success_responses", + self.litellm_deployment_success_responses = Counter( + name="litellm_deployment_success_responses", documentation="LLM Deployment Analytics - Total number of successful LLM API calls via litellm", labelnames=_logged_llm_labels, ) - self.llm_deployment_failure_responses = Counter( - name="llm_deployment_failure_responses", + self.litellm_deployment_failure_responses = Counter( + name="litellm_deployment_failure_responses", documentation="LLM Deployment Analytics - Total number of failed LLM API calls via litellm", labelnames=_logged_llm_labels, ) - self.llm_deployment_total_requests = Counter( - name="llm_deployment_total_requests", + self.litellm_deployment_total_requests = Counter( + name="litellm_deployment_total_requests", documentation="LLM Deployment Analytics - Total number of LLM API calls via litellm - success + failure", labelnames=_logged_llm_labels, ) # Deployment Latency tracking - self.llm_deployment_latency_per_output_token = Histogram( - name="llm_deployment_latency_per_output_token", + self.litellm_deployment_latency_per_output_token = Histogram( + name="litellm_deployment_latency_per_output_token", documentation="LLM Deployment Analytics - Latency per output token", labelnames=_logged_llm_labels, ) - self.llm_deployment_successful_fallbacks = Counter( - "llm_deployment_successful_fallbacks", + self.litellm_deployment_successful_fallbacks = Counter( + "litellm_deployment_successful_fallbacks", "LLM Deployment Analytics - Number of successful fallback requests from primary model -> fallback model", ["primary_model", "fallback_model"], ) - self.llm_deployment_failed_fallbacks = Counter( - "llm_deployment_failed_fallbacks", + self.litellm_deployment_failed_fallbacks = Counter( + "litellm_deployment_failed_fallbacks", "LLM Deployment Analytics - Number of failed fallback requests from primary model -> fallback model", ["primary_model", "fallback_model"], ) @@ -358,14 +358,14 @@ class PrometheusLogger(CustomLogger): api_provider=llm_provider, ) - self.llm_deployment_failure_responses.labels( + self.litellm_deployment_failure_responses.labels( litellm_model_name=litellm_model_name, model_id=model_id, api_base=api_base, api_provider=llm_provider, ).inc() - self.llm_deployment_total_requests.labels( + self.litellm_deployment_total_requests.labels( litellm_model_name=litellm_model_name, model_id=model_id, api_base=api_base, @@ -438,14 +438,14 @@ class PrometheusLogger(CustomLogger): api_provider=llm_provider, ) - self.llm_deployment_success_responses.labels( + self.litellm_deployment_success_responses.labels( litellm_model_name=litellm_model_name, model_id=model_id, api_base=api_base, api_provider=llm_provider, ).inc() - self.llm_deployment_total_requests.labels( + self.litellm_deployment_total_requests.labels( litellm_model_name=litellm_model_name, model_id=model_id, api_base=api_base, @@ -475,7 +475,7 @@ class 
PrometheusLogger(CustomLogger): latency_per_token = None if output_tokens is not None and output_tokens > 0: latency_per_token = _latency_seconds / output_tokens - self.llm_deployment_latency_per_output_token.labels( + self.litellm_deployment_latency_per_output_token.labels( litellm_model_name=litellm_model_name, model_id=model_id, api_base=api_base, @@ -497,7 +497,7 @@ class PrometheusLogger(CustomLogger): kwargs, ) _new_model = kwargs.get("model") - self.llm_deployment_successful_fallbacks.labels( + self.litellm_deployment_successful_fallbacks.labels( primary_model=original_model_group, fallback_model=_new_model ).inc() @@ -508,11 +508,11 @@ class PrometheusLogger(CustomLogger): kwargs, ) _new_model = kwargs.get("model") - self.llm_deployment_failed_fallbacks.labels( + self.litellm_deployment_failed_fallbacks.labels( primary_model=original_model_group, fallback_model=_new_model ).inc() - def set_deployment_state( + def set_litellm_deployment_state( self, state: int, litellm_model_name: str, @@ -520,7 +520,7 @@ class PrometheusLogger(CustomLogger): api_base: str, api_provider: str, ): - self.deployment_state.labels( + self.litellm_deployment_state.labels( litellm_model_name, model_id, api_base, api_provider ).set(state) @@ -531,7 +531,7 @@ class PrometheusLogger(CustomLogger): api_base: str, api_provider: str, ): - self.set_deployment_state( + self.set_litellm_deployment_state( 0, litellm_model_name, model_id, api_base, api_provider ) @@ -542,7 +542,7 @@ class PrometheusLogger(CustomLogger): api_base: str, api_provider: str, ): - self.set_deployment_state( + self.set_litellm_deployment_state( 1, litellm_model_name, model_id, api_base, api_provider ) @@ -553,7 +553,7 @@ class PrometheusLogger(CustomLogger): api_base: str, api_provider: str, ): - self.set_deployment_state( + self.set_litellm_deployment_state( 2, litellm_model_name, model_id, api_base, api_provider ) diff --git a/litellm/integrations/prometheus_helpers/prometheus_api.py b/litellm/integrations/prometheus_helpers/prometheus_api.py index 86764df7dd..13ccc15620 100644 --- a/litellm/integrations/prometheus_helpers/prometheus_api.py +++ b/litellm/integrations/prometheus_helpers/prometheus_api.py @@ -41,8 +41,8 @@ async def get_fallback_metric_from_prometheus(): """ response_message = "" relevant_metrics = [ - "llm_deployment_successful_fallbacks_total", - "llm_deployment_failed_fallbacks_total", + "litellm_deployment_successful_fallbacks_total", + "litellm_deployment_failed_fallbacks_total", ] for metric in relevant_metrics: response_json = await get_metric_from_prometheus( diff --git a/litellm/llms/anthropic.py b/litellm/llms/anthropic.py index 6f05aa226e..cf58163461 100644 --- a/litellm/llms/anthropic.py +++ b/litellm/llms/anthropic.py @@ -35,6 +35,7 @@ from litellm.types.llms.anthropic import ( AnthropicResponseContentBlockText, AnthropicResponseContentBlockToolUse, AnthropicResponseUsageBlock, + AnthropicSystemMessageContent, ContentBlockDelta, ContentBlockStart, ContentBlockStop, @@ -759,6 +760,7 @@ class AnthropicChatCompletion(BaseLLM): ## CALCULATING USAGE prompt_tokens = completion_response["usage"]["input_tokens"] completion_tokens = completion_response["usage"]["output_tokens"] + _usage = completion_response["usage"] total_tokens = prompt_tokens + completion_tokens model_response.created = int(time.time()) @@ -768,6 +770,11 @@ class AnthropicChatCompletion(BaseLLM): completion_tokens=completion_tokens, total_tokens=total_tokens, ) + + if "cache_creation_input_tokens" in _usage: + usage["cache_creation_input_tokens"] = 
_usage["cache_creation_input_tokens"] + if "cache_read_input_tokens" in _usage: + usage["cache_read_input_tokens"] = _usage["cache_read_input_tokens"] setattr(model_response, "usage", usage) # type: ignore return model_response @@ -901,6 +908,7 @@ class AnthropicChatCompletion(BaseLLM): # Separate system prompt from rest of message system_prompt_indices = [] system_prompt = "" + anthropic_system_message_list = None for idx, message in enumerate(messages): if message["role"] == "system": valid_content: bool = False @@ -908,8 +916,23 @@ class AnthropicChatCompletion(BaseLLM): system_prompt += message["content"] valid_content = True elif isinstance(message["content"], list): - for content in message["content"]: - system_prompt += content.get("text", "") + for _content in message["content"]: + anthropic_system_message_content = ( + AnthropicSystemMessageContent( + type=_content.get("type"), + text=_content.get("text"), + ) + ) + if "cache_control" in _content: + anthropic_system_message_content["cache_control"] = ( + _content["cache_control"] + ) + + if anthropic_system_message_list is None: + anthropic_system_message_list = [] + anthropic_system_message_list.append( + anthropic_system_message_content + ) valid_content = True if valid_content: @@ -919,6 +942,10 @@ class AnthropicChatCompletion(BaseLLM): messages.pop(idx) if len(system_prompt) > 0: optional_params["system"] = system_prompt + + # Handling anthropic API Prompt Caching + if anthropic_system_message_list is not None: + optional_params["system"] = anthropic_system_message_list # Format rest of message according to anthropic guidelines try: messages = prompt_factory( @@ -954,6 +981,8 @@ class AnthropicChatCompletion(BaseLLM): else: # assume openai tool call new_tool = tool["function"] new_tool["input_schema"] = new_tool.pop("parameters") # rename key + if "cache_control" in tool: + new_tool["cache_control"] = tool["cache_control"] anthropic_tools.append(new_tool) optional_params["tools"] = anthropic_tools diff --git a/litellm/llms/base_aws_llm.py b/litellm/llms/base_aws_llm.py new file mode 100644 index 0000000000..8de42eda73 --- /dev/null +++ b/litellm/llms/base_aws_llm.py @@ -0,0 +1,218 @@ +import json +from typing import List, Optional + +import httpx + +from litellm._logging import verbose_logger +from litellm.caching import DualCache, InMemoryCache +from litellm.utils import get_secret + +from .base import BaseLLM + + +class AwsAuthError(Exception): + def __init__(self, status_code, message): + self.status_code = status_code + self.message = message + self.request = httpx.Request( + method="POST", url="https://us-west-2.console.aws.amazon.com/bedrock" + ) + self.response = httpx.Response(status_code=status_code, request=self.request) + super().__init__( + self.message + ) # Call the base class constructor with the parameters it needs + + +class BaseAWSLLM(BaseLLM): + def __init__(self) -> None: + self.iam_cache = DualCache() + super().__init__() + + def get_credentials( + self, + aws_access_key_id: Optional[str] = None, + aws_secret_access_key: Optional[str] = None, + aws_session_token: Optional[str] = None, + aws_region_name: Optional[str] = None, + aws_session_name: Optional[str] = None, + aws_profile_name: Optional[str] = None, + aws_role_name: Optional[str] = None, + aws_web_identity_token: Optional[str] = None, + aws_sts_endpoint: Optional[str] = None, + ): + """ + Return a boto3.Credentials object + """ + import boto3 + + ## CHECK IS 'os.environ/' passed in + params_to_check: List[Optional[str]] = [ + aws_access_key_id, + 
aws_secret_access_key, + aws_session_token, + aws_region_name, + aws_session_name, + aws_profile_name, + aws_role_name, + aws_web_identity_token, + aws_sts_endpoint, + ] + + # Iterate over parameters and update if needed + for i, param in enumerate(params_to_check): + if param and param.startswith("os.environ/"): + _v = get_secret(param) + if _v is not None and isinstance(_v, str): + params_to_check[i] = _v + # Assign updated values back to parameters + ( + aws_access_key_id, + aws_secret_access_key, + aws_session_token, + aws_region_name, + aws_session_name, + aws_profile_name, + aws_role_name, + aws_web_identity_token, + aws_sts_endpoint, + ) = params_to_check + + verbose_logger.debug( + "in get credentials\n" + "aws_access_key_id=%s\n" + "aws_secret_access_key=%s\n" + "aws_session_token=%s\n" + "aws_region_name=%s\n" + "aws_session_name=%s\n" + "aws_profile_name=%s\n" + "aws_role_name=%s\n" + "aws_web_identity_token=%s\n" + "aws_sts_endpoint=%s", + aws_access_key_id, + aws_secret_access_key, + aws_session_token, + aws_region_name, + aws_session_name, + aws_profile_name, + aws_role_name, + aws_web_identity_token, + aws_sts_endpoint, + ) + + ### CHECK STS ### + if ( + aws_web_identity_token is not None + and aws_role_name is not None + and aws_session_name is not None + ): + verbose_logger.debug( + f"IN Web Identity Token: {aws_web_identity_token} | Role Name: {aws_role_name} | Session Name: {aws_session_name}" + ) + + if aws_sts_endpoint is None: + sts_endpoint = f"https://sts.{aws_region_name}.amazonaws.com" + else: + sts_endpoint = aws_sts_endpoint + + iam_creds_cache_key = json.dumps( + { + "aws_web_identity_token": aws_web_identity_token, + "aws_role_name": aws_role_name, + "aws_session_name": aws_session_name, + "aws_region_name": aws_region_name, + "aws_sts_endpoint": sts_endpoint, + } + ) + + iam_creds_dict = self.iam_cache.get_cache(iam_creds_cache_key) + if iam_creds_dict is None: + oidc_token = get_secret(aws_web_identity_token) + + if oidc_token is None: + raise AwsAuthError( + message="OIDC token could not be retrieved from secret manager.", + status_code=401, + ) + + sts_client = boto3.client( + "sts", + region_name=aws_region_name, + endpoint_url=sts_endpoint, + ) + + # https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html + # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role_with_web_identity.html + sts_response = sts_client.assume_role_with_web_identity( + RoleArn=aws_role_name, + RoleSessionName=aws_session_name, + WebIdentityToken=oidc_token, + DurationSeconds=3600, + ) + + iam_creds_dict = { + "aws_access_key_id": sts_response["Credentials"]["AccessKeyId"], + "aws_secret_access_key": sts_response["Credentials"][ + "SecretAccessKey" + ], + "aws_session_token": sts_response["Credentials"]["SessionToken"], + "region_name": aws_region_name, + } + + self.iam_cache.set_cache( + key=iam_creds_cache_key, + value=json.dumps(iam_creds_dict), + ttl=3600 - 60, + ) + + session = boto3.Session(**iam_creds_dict) + + iam_creds = session.get_credentials() + + return iam_creds + elif aws_role_name is not None and aws_session_name is not None: + sts_client = boto3.client( + "sts", + aws_access_key_id=aws_access_key_id, # [OPTIONAL] + aws_secret_access_key=aws_secret_access_key, # [OPTIONAL] + ) + + sts_response = sts_client.assume_role( + RoleArn=aws_role_name, RoleSessionName=aws_session_name + ) + + # Extract the credentials from the response and convert to Session Credentials + sts_credentials = 
sts_response["Credentials"] + from botocore.credentials import Credentials + + credentials = Credentials( + access_key=sts_credentials["AccessKeyId"], + secret_key=sts_credentials["SecretAccessKey"], + token=sts_credentials["SessionToken"], + ) + return credentials + elif aws_profile_name is not None: ### CHECK SESSION ### + # uses auth values from AWS profile usually stored in ~/.aws/credentials + client = boto3.Session(profile_name=aws_profile_name) + + return client.get_credentials() + elif ( + aws_access_key_id is not None + and aws_secret_access_key is not None + and aws_session_token is not None + ): ### CHECK FOR AWS SESSION TOKEN ### + from botocore.credentials import Credentials + + credentials = Credentials( + access_key=aws_access_key_id, + secret_key=aws_secret_access_key, + token=aws_session_token, + ) + return credentials + else: + session = boto3.Session( + aws_access_key_id=aws_access_key_id, + aws_secret_access_key=aws_secret_access_key, + region_name=aws_region_name, + ) + + return session.get_credentials() diff --git a/litellm/llms/bedrock_httpx.py b/litellm/llms/bedrock_httpx.py index ffc096f762..73387212ff 100644 --- a/litellm/llms/bedrock_httpx.py +++ b/litellm/llms/bedrock_httpx.py @@ -57,6 +57,7 @@ from litellm.utils import ( ) from .base import BaseLLM +from .base_aws_llm import BaseAWSLLM from .bedrock import BedrockError, ModelResponseIterator, convert_messages_to_prompt from .prompt_templates.factory import ( _bedrock_converse_messages_pt, @@ -87,7 +88,6 @@ BEDROCK_CONVERSE_MODELS = [ ] -iam_cache = DualCache() _response_stream_shape_cache = None bedrock_tool_name_mappings: InMemoryCache = InMemoryCache( max_size_in_memory=50, default_ttl=600 @@ -312,7 +312,7 @@ def make_sync_call( return completion_stream -class BedrockLLM(BaseLLM): +class BedrockLLM(BaseAWSLLM): """ Example call @@ -380,183 +380,6 @@ class BedrockLLM(BaseLLM): prompt += f"{message['content']}" return prompt, chat_history # type: ignore - def get_credentials( - self, - aws_access_key_id: Optional[str] = None, - aws_secret_access_key: Optional[str] = None, - aws_session_token: Optional[str] = None, - aws_region_name: Optional[str] = None, - aws_session_name: Optional[str] = None, - aws_profile_name: Optional[str] = None, - aws_role_name: Optional[str] = None, - aws_web_identity_token: Optional[str] = None, - aws_sts_endpoint: Optional[str] = None, - ): - """ - Return a boto3.Credentials object - """ - import boto3 - - print_verbose( - f"Boto3 get_credentials called variables passed to function {locals()}" - ) - - ## CHECK IS 'os.environ/' passed in - params_to_check: List[Optional[str]] = [ - aws_access_key_id, - aws_secret_access_key, - aws_session_token, - aws_region_name, - aws_session_name, - aws_profile_name, - aws_role_name, - aws_web_identity_token, - aws_sts_endpoint, - ] - - # Iterate over parameters and update if needed - for i, param in enumerate(params_to_check): - if param and param.startswith("os.environ/"): - _v = get_secret(param) - if _v is not None and isinstance(_v, str): - params_to_check[i] = _v - # Assign updated values back to parameters - ( - aws_access_key_id, - aws_secret_access_key, - aws_session_token, - aws_region_name, - aws_session_name, - aws_profile_name, - aws_role_name, - aws_web_identity_token, - aws_sts_endpoint, - ) = params_to_check - - ### CHECK STS ### - if ( - aws_web_identity_token is not None - and aws_role_name is not None - and aws_session_name is not None - ): - print_verbose( - f"IN Web Identity Token: {aws_web_identity_token} | Role Name: 
{aws_role_name} | Session Name: {aws_session_name}" - ) - - if aws_sts_endpoint is None: - sts_endpoint = f"https://sts.{aws_region_name}.amazonaws.com" - else: - sts_endpoint = aws_sts_endpoint - - iam_creds_cache_key = json.dumps( - { - "aws_web_identity_token": aws_web_identity_token, - "aws_role_name": aws_role_name, - "aws_session_name": aws_session_name, - "aws_region_name": aws_region_name, - "aws_sts_endpoint": sts_endpoint, - } - ) - - iam_creds_dict = iam_cache.get_cache(iam_creds_cache_key) - if iam_creds_dict is None: - oidc_token = get_secret(aws_web_identity_token) - - if oidc_token is None: - raise BedrockError( - message="OIDC token could not be retrieved from secret manager.", - status_code=401, - ) - - sts_client = boto3.client( - "sts", - region_name=aws_region_name, - endpoint_url=sts_endpoint, - ) - - # https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html - # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role_with_web_identity.html - sts_response = sts_client.assume_role_with_web_identity( - RoleArn=aws_role_name, - RoleSessionName=aws_session_name, - WebIdentityToken=oidc_token, - DurationSeconds=3600, - ) - - iam_creds_dict = { - "aws_access_key_id": sts_response["Credentials"]["AccessKeyId"], - "aws_secret_access_key": sts_response["Credentials"][ - "SecretAccessKey" - ], - "aws_session_token": sts_response["Credentials"]["SessionToken"], - "region_name": aws_region_name, - } - - iam_cache.set_cache( - key=iam_creds_cache_key, - value=json.dumps(iam_creds_dict), - ttl=3600 - 60, - ) - - session = boto3.Session(**iam_creds_dict) - - iam_creds = session.get_credentials() - - return iam_creds - elif aws_role_name is not None and aws_session_name is not None: - print_verbose( - f"Using STS Client AWS aws_role_name: {aws_role_name} aws_session_name: {aws_session_name}" - ) - sts_client = boto3.client( - "sts", - aws_access_key_id=aws_access_key_id, # [OPTIONAL] - aws_secret_access_key=aws_secret_access_key, # [OPTIONAL] - ) - - sts_response = sts_client.assume_role( - RoleArn=aws_role_name, RoleSessionName=aws_session_name - ) - - # Extract the credentials from the response and convert to Session Credentials - sts_credentials = sts_response["Credentials"] - from botocore.credentials import Credentials - - credentials = Credentials( - access_key=sts_credentials["AccessKeyId"], - secret_key=sts_credentials["SecretAccessKey"], - token=sts_credentials["SessionToken"], - ) - return credentials - elif aws_profile_name is not None: ### CHECK SESSION ### - # uses auth values from AWS profile usually stored in ~/.aws/credentials - print_verbose(f"Using AWS profile: {aws_profile_name}") - client = boto3.Session(profile_name=aws_profile_name) - - return client.get_credentials() - elif ( - aws_access_key_id is not None - and aws_secret_access_key is not None - and aws_session_token is not None - ): ### CHECK FOR AWS SESSION TOKEN ### - print_verbose(f"Using AWS Session Token: {aws_session_token}") - from botocore.credentials import Credentials - - credentials = Credentials( - access_key=aws_access_key_id, - secret_key=aws_secret_access_key, - token=aws_session_token, - ) - return credentials - else: - print_verbose("Using Default AWS Session") - session = boto3.Session( - aws_access_key_id=aws_access_key_id, - aws_secret_access_key=aws_secret_access_key, - region_name=aws_region_name, - ) - - return session.get_credentials() - def process_response( self, model: str, @@ -1055,8 +878,8 @@ class 
BedrockLLM(BaseLLM): }, ) raise BedrockError( - status_code=400, - message="Bedrock HTTPX: Unsupported provider={}, model={}".format( + status_code=404, + message="Bedrock HTTPX: Unknown provider={}, model={}".format( provider, model ), ) @@ -1414,7 +1237,7 @@ class AmazonConverseConfig: return optional_params -class BedrockConverseLLM(BaseLLM): +class BedrockConverseLLM(BaseAWSLLM): def __init__(self) -> None: super().__init__() @@ -1554,173 +1377,6 @@ class BedrockConverseLLM(BaseLLM): """ return urllib.parse.quote(model_id, safe="") - def get_credentials( - self, - aws_access_key_id: Optional[str] = None, - aws_secret_access_key: Optional[str] = None, - aws_session_token: Optional[str] = None, - aws_region_name: Optional[str] = None, - aws_session_name: Optional[str] = None, - aws_profile_name: Optional[str] = None, - aws_role_name: Optional[str] = None, - aws_web_identity_token: Optional[str] = None, - aws_sts_endpoint: Optional[str] = None, - ): - """ - Return a boto3.Credentials object - """ - import boto3 - - ## CHECK IS 'os.environ/' passed in - params_to_check: List[Optional[str]] = [ - aws_access_key_id, - aws_secret_access_key, - aws_session_token, - aws_region_name, - aws_session_name, - aws_profile_name, - aws_role_name, - aws_web_identity_token, - aws_sts_endpoint, - ] - - # Iterate over parameters and update if needed - for i, param in enumerate(params_to_check): - if param and param.startswith("os.environ/"): - _v = get_secret(param) - if _v is not None and isinstance(_v, str): - params_to_check[i] = _v - # Assign updated values back to parameters - ( - aws_access_key_id, - aws_secret_access_key, - aws_session_token, - aws_region_name, - aws_session_name, - aws_profile_name, - aws_role_name, - aws_web_identity_token, - aws_sts_endpoint, - ) = params_to_check - - ### CHECK STS ### - if ( - aws_web_identity_token is not None - and aws_role_name is not None - and aws_session_name is not None - ): - print_verbose( - f"IN Web Identity Token: {aws_web_identity_token} | Role Name: {aws_role_name} | Session Name: {aws_session_name}" - ) - - if aws_sts_endpoint is None: - sts_endpoint = f"https://sts.{aws_region_name}.amazonaws.com" - else: - sts_endpoint = aws_sts_endpoint - - iam_creds_cache_key = json.dumps( - { - "aws_web_identity_token": aws_web_identity_token, - "aws_role_name": aws_role_name, - "aws_session_name": aws_session_name, - "aws_region_name": aws_region_name, - "aws_sts_endpoint": sts_endpoint, - } - ) - - iam_creds_dict = iam_cache.get_cache(iam_creds_cache_key) - if iam_creds_dict is None: - oidc_token = get_secret(aws_web_identity_token) - - if oidc_token is None: - raise BedrockError( - message="OIDC token could not be retrieved from secret manager.", - status_code=401, - ) - - sts_client = boto3.client( - "sts", - region_name=aws_region_name, - endpoint_url=sts_endpoint, - ) - - # https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html - # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role_with_web_identity.html - sts_response = sts_client.assume_role_with_web_identity( - RoleArn=aws_role_name, - RoleSessionName=aws_session_name, - WebIdentityToken=oidc_token, - DurationSeconds=3600, - ) - - iam_creds_dict = { - "aws_access_key_id": sts_response["Credentials"]["AccessKeyId"], - "aws_secret_access_key": sts_response["Credentials"][ - "SecretAccessKey" - ], - "aws_session_token": sts_response["Credentials"]["SessionToken"], - "region_name": aws_region_name, - } - - iam_cache.set_cache( 
- key=iam_creds_cache_key, - value=json.dumps(iam_creds_dict), - ttl=3600 - 60, - ) - - session = boto3.Session(**iam_creds_dict) - - iam_creds = session.get_credentials() - - return iam_creds - elif aws_role_name is not None and aws_session_name is not None: - sts_client = boto3.client( - "sts", - aws_access_key_id=aws_access_key_id, # [OPTIONAL] - aws_secret_access_key=aws_secret_access_key, # [OPTIONAL] - ) - - sts_response = sts_client.assume_role( - RoleArn=aws_role_name, RoleSessionName=aws_session_name - ) - - # Extract the credentials from the response and convert to Session Credentials - sts_credentials = sts_response["Credentials"] - from botocore.credentials import Credentials - - credentials = Credentials( - access_key=sts_credentials["AccessKeyId"], - secret_key=sts_credentials["SecretAccessKey"], - token=sts_credentials["SessionToken"], - ) - return credentials - elif aws_profile_name is not None: ### CHECK SESSION ### - # uses auth values from AWS profile usually stored in ~/.aws/credentials - client = boto3.Session(profile_name=aws_profile_name) - - return client.get_credentials() - elif ( - aws_access_key_id is not None - and aws_secret_access_key is not None - and aws_session_token is not None - ): ### CHECK FOR AWS SESSION TOKEN ### - from botocore.credentials import Credentials - - credentials = Credentials( - access_key=aws_access_key_id, - secret_key=aws_secret_access_key, - token=aws_session_token, - ) - return credentials - else: - session = boto3.Session( - aws_access_key_id=aws_access_key_id, - aws_secret_access_key=aws_secret_access_key, - region_name=aws_region_name, - ) - - return session.get_credentials() - async def async_streaming( self, model: str, diff --git a/litellm/llms/ollama.py b/litellm/llms/ollama.py index 6b984e1d82..f699cf0f5f 100644 --- a/litellm/llms/ollama.py +++ b/litellm/llms/ollama.py @@ -601,12 +601,13 @@ def ollama_embeddings( ): return asyncio.run( ollama_aembeddings( - api_base, - model, - prompts, - optional_params, - logging_obj, - model_response, - encoding, + api_base=api_base, + model=model, + prompts=prompts, + model_response=model_response, + optional_params=optional_params, + logging_obj=logging_obj, + encoding=encoding, ) + ) diff --git a/litellm/llms/ollama_chat.py b/litellm/llms/ollama_chat.py index b0dd5d905a..ea84fa95cf 100644 --- a/litellm/llms/ollama_chat.py +++ b/litellm/llms/ollama_chat.py @@ -356,6 +356,7 @@ def ollama_completion_stream(url, api_key, data, logging_obj): "json": data, "method": "POST", "timeout": litellm.request_timeout, + "follow_redirects": True } if api_key is not None: _request["headers"] = {"Authorization": "Bearer {}".format(api_key)} diff --git a/litellm/llms/prompt_templates/factory.py b/litellm/llms/prompt_templates/factory.py index 7c3c7e80fb..f81515e98d 100644 --- a/litellm/llms/prompt_templates/factory.py +++ b/litellm/llms/prompt_templates/factory.py @@ -1224,6 +1224,19 @@ def convert_to_anthropic_tool_invoke( return anthropic_tool_invoke +def add_cache_control_to_content( + anthropic_content_element: Union[ + dict, AnthropicMessagesImageParam, AnthropicMessagesTextParam + ], + orignal_content_element: dict, +): + if "cache_control" in orignal_content_element: + anthropic_content_element["cache_control"] = orignal_content_element[ + "cache_control" + ] + return anthropic_content_element + + def anthropic_messages_pt( messages: list, model: str, @@ -1264,18 +1277,31 @@ def anthropic_messages_pt( image_chunk = convert_to_anthropic_image_obj( m["image_url"]["url"] ) - user_content.append( - 
AnthropicMessagesImageParam( - type="image", - source=AnthropicImageParamSource( - type="base64", - media_type=image_chunk["media_type"], - data=image_chunk["data"], - ), - ) + + _anthropic_content_element = AnthropicMessagesImageParam( + type="image", + source=AnthropicImageParamSource( + type="base64", + media_type=image_chunk["media_type"], + data=image_chunk["data"], + ), ) + + anthropic_content_element = add_cache_control_to_content( + anthropic_content_element=_anthropic_content_element, + orignal_content_element=m, + ) + user_content.append(anthropic_content_element) elif m.get("type", "") == "text": - user_content.append({"type": "text", "text": m["text"]}) + _anthropic_text_content_element = { + "type": "text", + "text": m["text"], + } + anthropic_content_element = add_cache_control_to_content( + anthropic_content_element=_anthropic_text_content_element, + orignal_content_element=m, + ) + user_content.append(anthropic_content_element) elif ( messages[msg_i]["role"] == "tool" or messages[msg_i]["role"] == "function" @@ -1306,6 +1332,10 @@ def anthropic_messages_pt( anthropic_message = AnthropicMessagesTextParam( type="text", text=m.get("text") ) + anthropic_message = add_cache_control_to_content( + anthropic_content_element=anthropic_message, + orignal_content_element=m, + ) assistant_content.append(anthropic_message) elif ( "content" in messages[msg_i] @@ -1313,9 +1343,17 @@ def anthropic_messages_pt( and len(messages[msg_i]["content"]) > 0 # don't pass empty text blocks. anthropic api raises errors. ): - assistant_content.append( - {"type": "text", "text": messages[msg_i]["content"]} + + _anthropic_text_content_element = { + "type": "text", + "text": messages[msg_i]["content"], + } + + anthropic_content_element = add_cache_control_to_content( + anthropic_content_element=_anthropic_text_content_element, + orignal_content_element=messages[msg_i], ) + assistant_content.append(anthropic_content_element) if messages[msg_i].get( "tool_calls", [] @@ -1701,12 +1739,14 @@ def cohere_messages_pt_v2( assistant_tool_calls: List[ToolCallObject] = [] ## MERGE CONSECUTIVE ASSISTANT CONTENT ## while msg_i < len(messages) and messages[msg_i]["role"] == "assistant": - assistant_text = ( - messages[msg_i].get("content") or "" - ) # either string or none - if assistant_text: - assistant_content += assistant_text - + if isinstance(messages[msg_i]["content"], list): + for m in messages[msg_i]["content"]: + if m.get("type", "") == "text": + assistant_content += m["text"] + elif messages[msg_i].get("content") is not None and isinstance( + messages[msg_i]["content"], str + ): + assistant_content += messages[msg_i]["content"] if messages[msg_i].get( "tool_calls", [] ): # support assistant tool invoke conversion diff --git a/litellm/llms/sagemaker.py b/litellm/llms/sagemaker.py index d16d2bd11b..32146b9cae 100644 --- a/litellm/llms/sagemaker.py +++ b/litellm/llms/sagemaker.py @@ -7,16 +7,38 @@ import traceback import types from copy import deepcopy from enum import Enum -from typing import Any, Callable, Optional +from functools import partial +from typing import Any, AsyncIterator, Callable, Iterator, List, Optional, Union import httpx # type: ignore import requests # type: ignore import litellm -from litellm.utils import EmbeddingResponse, ModelResponse, Usage, get_secret +from litellm._logging import verbose_logger +from litellm.llms.custom_httpx.http_handler import ( + AsyncHTTPHandler, + HTTPHandler, + _get_async_httpx_client, + _get_httpx_client, +) +from litellm.types.llms.openai import ( + 
ChatCompletionToolCallChunk, + ChatCompletionUsageBlock, +) +from litellm.types.utils import GenericStreamingChunk as GChunk +from litellm.utils import ( + CustomStreamWrapper, + EmbeddingResponse, + ModelResponse, + Usage, + get_secret, +) +from .base_aws_llm import BaseAWSLLM from .prompt_templates.factory import custom_prompt, prompt_factory +_response_stream_shape_cache = None + class SagemakerError(Exception): def __init__(self, status_code, message): @@ -31,73 +53,6 @@ class SagemakerError(Exception): ) # Call the base class constructor with the parameters it needs -class TokenIterator: - def __init__(self, stream, acompletion: bool = False): - if acompletion == False: - self.byte_iterator = iter(stream) - elif acompletion == True: - self.byte_iterator = stream - self.buffer = io.BytesIO() - self.read_pos = 0 - self.end_of_data = False - - def __iter__(self): - return self - - def __next__(self): - try: - while True: - self.buffer.seek(self.read_pos) - line = self.buffer.readline() - if line and line[-1] == ord("\n"): - response_obj = {"text": "", "is_finished": False} - self.read_pos += len(line) + 1 - full_line = line[:-1].decode("utf-8") - line_data = json.loads(full_line.lstrip("data:").rstrip("/n")) - if line_data.get("generated_text", None) is not None: - self.end_of_data = True - response_obj["is_finished"] = True - response_obj["text"] = line_data["token"]["text"] - return response_obj - chunk = next(self.byte_iterator) - self.buffer.seek(0, io.SEEK_END) - self.buffer.write(chunk["PayloadPart"]["Bytes"]) - except StopIteration as e: - if self.end_of_data == True: - raise e # Re-raise StopIteration - else: - self.end_of_data = True - return "data: [DONE]" - - def __aiter__(self): - return self - - async def __anext__(self): - try: - while True: - self.buffer.seek(self.read_pos) - line = self.buffer.readline() - if line and line[-1] == ord("\n"): - response_obj = {"text": "", "is_finished": False} - self.read_pos += len(line) + 1 - full_line = line[:-1].decode("utf-8") - line_data = json.loads(full_line.lstrip("data:").rstrip("/n")) - if line_data.get("generated_text", None) is not None: - self.end_of_data = True - response_obj["is_finished"] = True - response_obj["text"] = line_data["token"]["text"] - return response_obj - chunk = await self.byte_iterator.__anext__() - self.buffer.seek(0, io.SEEK_END) - self.buffer.write(chunk["PayloadPart"]["Bytes"]) - except StopAsyncIteration as e: - if self.end_of_data == True: - raise e # Re-raise StopIteration - else: - self.end_of_data = True - return "data: [DONE]" - - class SagemakerConfig: """ Reference: https://d-uuwbxj1u4cnu.studio.us-west-2.sagemaker.aws/jupyter/default/lab/workspaces/auto-q/tree/DemoNotebooks/meta-textgeneration-llama-2-7b-SDK_1.ipynb @@ -145,439 +100,498 @@ os.environ['AWS_ACCESS_KEY_ID'] = "" os.environ['AWS_SECRET_ACCESS_KEY'] = "" """ + # set os.environ['AWS_REGION_NAME'] = +class SagemakerLLM(BaseAWSLLM): + def _load_credentials( + self, + optional_params: dict, + ): + try: + from botocore.credentials import Credentials + except ImportError as e: + raise ImportError("Missing boto3 to call bedrock. 
Run 'pip install boto3'.") + ## CREDENTIALS ## + # pop aws_secret_access_key, aws_access_key_id, aws_session_token, aws_region_name from kwargs, since completion calls fail with them + aws_secret_access_key = optional_params.pop("aws_secret_access_key", None) + aws_access_key_id = optional_params.pop("aws_access_key_id", None) + aws_session_token = optional_params.pop("aws_session_token", None) + aws_region_name = optional_params.pop("aws_region_name", None) + aws_role_name = optional_params.pop("aws_role_name", None) + aws_session_name = optional_params.pop("aws_session_name", None) + aws_profile_name = optional_params.pop("aws_profile_name", None) + aws_bedrock_runtime_endpoint = optional_params.pop( + "aws_bedrock_runtime_endpoint", None + ) # https://bedrock-runtime.{region_name}.amazonaws.com + aws_web_identity_token = optional_params.pop("aws_web_identity_token", None) + aws_sts_endpoint = optional_params.pop("aws_sts_endpoint", None) -def completion( - model: str, - messages: list, - model_response: ModelResponse, - print_verbose: Callable, - encoding, - logging_obj, - custom_prompt_dict={}, - hf_model_name=None, - optional_params=None, - litellm_params=None, - logger_fn=None, - acompletion: bool = False, -): - import boto3 + ### SET REGION NAME ### + if aws_region_name is None: + # check env # + litellm_aws_region_name = get_secret("AWS_REGION_NAME", None) - # pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them - aws_secret_access_key = optional_params.pop("aws_secret_access_key", None) - aws_access_key_id = optional_params.pop("aws_access_key_id", None) - aws_region_name = optional_params.pop("aws_region_name", None) - model_id = optional_params.pop("model_id", None) + if litellm_aws_region_name is not None and isinstance( + litellm_aws_region_name, str + ): + aws_region_name = litellm_aws_region_name - if aws_access_key_id != None: - # uses auth params passed to completion - # aws_access_key_id is not None, assume user is trying to auth using litellm.completion - client = boto3.client( - service_name="sagemaker-runtime", + standard_aws_region_name = get_secret("AWS_REGION", None) + if standard_aws_region_name is not None and isinstance( + standard_aws_region_name, str + ): + aws_region_name = standard_aws_region_name + + if aws_region_name is None: + aws_region_name = "us-west-2" + + credentials: Credentials = self.get_credentials( aws_access_key_id=aws_access_key_id, aws_secret_access_key=aws_secret_access_key, - region_name=aws_region_name, + aws_session_token=aws_session_token, + aws_region_name=aws_region_name, + aws_session_name=aws_session_name, + aws_profile_name=aws_profile_name, + aws_role_name=aws_role_name, + aws_web_identity_token=aws_web_identity_token, + aws_sts_endpoint=aws_sts_endpoint, ) - else: - # aws_access_key_id is None, assume user is trying to auth using env variables - # boto3 automaticaly reads env variables + return credentials, aws_region_name - # we need to read region name from env - # I assume majority of users use .env for auth - region_name = ( - get_secret("AWS_REGION_NAME") - or aws_region_name # get region from config file if specified - or "us-west-2" # default to us-west-2 if region not specified - ) - client = boto3.client( - service_name="sagemaker-runtime", - region_name=region_name, - ) + def _prepare_request( + self, + credentials, + model: str, + data: dict, + optional_params: dict, + aws_region_name: str, + extra_headers: Optional[dict] = None, + ): + try: + import boto3 + 
from botocore.auth import SigV4Auth + from botocore.awsrequest import AWSRequest + from botocore.credentials import Credentials + except ImportError as e: + raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.") - # pop streaming if it's in the optional params as 'stream' raises an error with sagemaker - inference_params = deepcopy(optional_params) + sigv4 = SigV4Auth(credentials, "sagemaker", aws_region_name) + if optional_params.get("stream") is True: + api_base = f"https://runtime.sagemaker.{aws_region_name}.amazonaws.com/endpoints/{model}/invocations-response-stream" + else: + api_base = f"https://runtime.sagemaker.{aws_region_name}.amazonaws.com/endpoints/{model}/invocations" - ## Load Config - config = litellm.SagemakerConfig.get_config() - for k, v in config.items(): - if ( - k not in inference_params - ): # completion(top_k=3) > sagemaker_config(top_k=3) <- allows for dynamic variables to be passed in - inference_params[k] = v + encoded_data = json.dumps(data).encode("utf-8") + headers = {"Content-Type": "application/json"} + if extra_headers is not None: + headers = {"Content-Type": "application/json", **extra_headers} + request = AWSRequest( + method="POST", url=api_base, data=encoded_data, headers=headers + ) + sigv4.add_auth(request) + prepped_request = request.prepare() - model = model - if model in custom_prompt_dict: - # check if the model has a registered custom prompt - model_prompt_details = custom_prompt_dict[model] - prompt = custom_prompt( - role_dict=model_prompt_details.get("roles", None), - initial_prompt_value=model_prompt_details.get("initial_prompt_value", ""), - final_prompt_value=model_prompt_details.get("final_prompt_value", ""), - messages=messages, - ) - elif hf_model_name in custom_prompt_dict: - # check if the base huggingface model has a registered custom prompt - model_prompt_details = custom_prompt_dict[hf_model_name] - prompt = custom_prompt( - role_dict=model_prompt_details.get("roles", None), - initial_prompt_value=model_prompt_details.get("initial_prompt_value", ""), - final_prompt_value=model_prompt_details.get("final_prompt_value", ""), - messages=messages, - ) - else: - if hf_model_name is None: - if "llama-2" in model.lower(): # llama-2 model - if "chat" in model.lower(): # apply llama2 chat template - hf_model_name = "meta-llama/Llama-2-7b-chat-hf" - else: # apply regular llama2 template - hf_model_name = "meta-llama/Llama-2-7b" - hf_model_name = ( - hf_model_name or model - ) # pass in hf model name for pulling it's prompt template - (e.g. 
`hf_model_name="meta-llama/Llama-2-7b-chat-hf` applies the llama2 chat template to the prompt) - prompt = prompt_factory(model=hf_model_name, messages=messages) - stream = inference_params.pop("stream", None) - if stream == True: - data = json.dumps( - {"inputs": prompt, "parameters": inference_params, "stream": True} - ).encode("utf-8") - if acompletion == True: - response = async_streaming( - optional_params=optional_params, - encoding=encoding, - model_response=model_response, + return prepped_request + + def completion( + self, + model: str, + messages: list, + model_response: ModelResponse, + print_verbose: Callable, + encoding, + logging_obj, + custom_prompt_dict={}, + hf_model_name=None, + optional_params=None, + litellm_params=None, + logger_fn=None, + acompletion: bool = False, + ): + + # pop streaming if it's in the optional params as 'stream' raises an error with sagemaker + credentials, aws_region_name = self._load_credentials(optional_params) + inference_params = deepcopy(optional_params) + + ## Load Config + config = litellm.SagemakerConfig.get_config() + for k, v in config.items(): + if ( + k not in inference_params + ): # completion(top_k=3) > sagemaker_config(top_k=3) <- allows for dynamic variables to be passed in + inference_params[k] = v + + if model in custom_prompt_dict: + # check if the model has a registered custom prompt + model_prompt_details = custom_prompt_dict[model] + prompt = custom_prompt( + role_dict=model_prompt_details.get("roles", None), + initial_prompt_value=model_prompt_details.get( + "initial_prompt_value", "" + ), + final_prompt_value=model_prompt_details.get("final_prompt_value", ""), + messages=messages, + ) + elif hf_model_name in custom_prompt_dict: + # check if the base huggingface model has a registered custom prompt + model_prompt_details = custom_prompt_dict[hf_model_name] + prompt = custom_prompt( + role_dict=model_prompt_details.get("roles", None), + initial_prompt_value=model_prompt_details.get( + "initial_prompt_value", "" + ), + final_prompt_value=model_prompt_details.get("final_prompt_value", ""), + messages=messages, + ) + else: + if hf_model_name is None: + if "llama-2" in model.lower(): # llama-2 model + if "chat" in model.lower(): # apply llama2 chat template + hf_model_name = "meta-llama/Llama-2-7b-chat-hf" + else: # apply regular llama2 template + hf_model_name = "meta-llama/Llama-2-7b" + hf_model_name = ( + hf_model_name or model + ) # pass in hf model name for pulling it's prompt template - (e.g. 
`hf_model_name="meta-llama/Llama-2-7b-chat-hf` applies the llama2 chat template to the prompt) + prompt = prompt_factory(model=hf_model_name, messages=messages) + stream = inference_params.pop("stream", None) + model_id = optional_params.get("model_id", None) + + if stream is True: + data = {"inputs": prompt, "parameters": inference_params, "stream": True} + prepared_request = self._prepare_request( model=model, - logging_obj=logging_obj, data=data, - model_id=model_id, - aws_secret_access_key=aws_secret_access_key, - aws_access_key_id=aws_access_key_id, + optional_params=optional_params, + credentials=credentials, aws_region_name=aws_region_name, ) - return response + if model_id is not None: + # Add model_id as InferenceComponentName header + # boto3 doc: https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_runtime_InvokeEndpoint.html + prepared_request.headers.update( + {"X-Amzn-SageMaker-Inference-Componen": model_id} + ) - if model_id is not None: - response = client.invoke_endpoint_with_response_stream( - EndpointName=model, - InferenceComponentName=model_id, - ContentType="application/json", - Body=data, - CustomAttributes="accept_eula=true", + if acompletion is True: + response = self.async_streaming( + prepared_request=prepared_request, + optional_params=optional_params, + encoding=encoding, + model_response=model_response, + model=model, + logging_obj=logging_obj, + data=data, + model_id=model_id, + ) + return response + else: + if stream is not None and stream == True: + sync_handler = _get_httpx_client() + sync_response = sync_handler.post( + url=prepared_request.url, + headers=prepared_request.headers, # type: ignore + json=data, + stream=stream, + ) + + if sync_response.status_code != 200: + raise SagemakerError( + status_code=sync_response.status_code, + message=sync_response.read(), + ) + + decoder = AWSEventStreamDecoder(model="") + + completion_stream = decoder.iter_bytes( + sync_response.iter_bytes(chunk_size=1024) + ) + streaming_response = CustomStreamWrapper( + completion_stream=completion_stream, + model=model, + custom_llm_provider="sagemaker", + logging_obj=logging_obj, + ) + + ## LOGGING + logging_obj.post_call( + input=messages, + api_key="", + original_response=streaming_response, + additional_args={"complete_input_dict": data}, ) - else: - response = client.invoke_endpoint_with_response_stream( - EndpointName=model, - ContentType="application/json", - Body=data, - CustomAttributes="accept_eula=true", - ) - return response["Body"] - elif acompletion == True: + return streaming_response + + # Non-Streaming Requests _data = {"inputs": prompt, "parameters": inference_params} - return async_completion( - optional_params=optional_params, - encoding=encoding, - model_response=model_response, + prepared_request = self._prepare_request( model=model, - logging_obj=logging_obj, data=_data, - model_id=model_id, - aws_secret_access_key=aws_secret_access_key, - aws_access_key_id=aws_access_key_id, + optional_params=optional_params, + credentials=credentials, aws_region_name=aws_region_name, ) - data = json.dumps({"inputs": prompt, "parameters": inference_params}).encode( - "utf-8" - ) - ## COMPLETION CALL - try: - if model_id is not None: - ## LOGGING - request_str = f""" - response = client.invoke_endpoint( - EndpointName={model}, - InferenceComponentName={model_id}, - ContentType="application/json", - Body={data}, # type: ignore - CustomAttributes="accept_eula=true", + + # Async completion + if acompletion == True: + return self.async_completion( + 
prepared_request=prepared_request, + model_response=model_response, + encoding=encoding, + model=model, + logging_obj=logging_obj, + data=_data, + model_id=model_id, ) - """ # type: ignore - logging_obj.pre_call( - input=prompt, - api_key="", - additional_args={ - "complete_input_dict": data, - "request_str": request_str, - "hf_model_name": hf_model_name, - }, - ) - response = client.invoke_endpoint( - EndpointName=model, - InferenceComponentName=model_id, - ContentType="application/json", - Body=data, - CustomAttributes="accept_eula=true", - ) - else: - ## LOGGING - request_str = f""" - response = client.invoke_endpoint( - EndpointName={model}, - ContentType="application/json", - Body={data}, # type: ignore - CustomAttributes="accept_eula=true", - ) - """ # type: ignore - logging_obj.pre_call( - input=prompt, - api_key="", - additional_args={ - "complete_input_dict": data, - "request_str": request_str, - "hf_model_name": hf_model_name, - }, - ) - response = client.invoke_endpoint( - EndpointName=model, - ContentType="application/json", - Body=data, - CustomAttributes="accept_eula=true", - ) - except Exception as e: - status_code = ( - getattr(e, "response", {}) - .get("ResponseMetadata", {}) - .get("HTTPStatusCode", 500) - ) - error_message = ( - getattr(e, "response", {}).get("Error", {}).get("Message", str(e)) - ) - if "Inference Component Name header is required" in error_message: - error_message += "\n pass in via `litellm.completion(..., model_id={InferenceComponentName})`" - raise SagemakerError(status_code=status_code, message=error_message) - - response = response["Body"].read().decode("utf8") - ## LOGGING - logging_obj.post_call( - input=prompt, - api_key="", - original_response=response, - additional_args={"complete_input_dict": data}, - ) - print_verbose(f"raw model_response: {response}") - ## RESPONSE OBJECT - completion_response = json.loads(response) - try: - if isinstance(completion_response, list): - completion_response_choices = completion_response[0] - else: - completion_response_choices = completion_response - completion_output = "" - if "generation" in completion_response_choices: - completion_output += completion_response_choices["generation"] - elif "generated_text" in completion_response_choices: - completion_output += completion_response_choices["generated_text"] - - # check if the prompt template is part of output, if so - filter it out - if completion_output.startswith(prompt) and "" in prompt: - completion_output = completion_output.replace(prompt, "", 1) - - model_response.choices[0].message.content = completion_output # type: ignore - except: - raise SagemakerError( - message=f"LiteLLM Error: Unable to parse sagemaker RAW RESPONSE {json.dumps(completion_response)}", - status_code=500, - ) - - ## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here. 
- prompt_tokens = len(encoding.encode(prompt)) - completion_tokens = len( - encoding.encode(model_response["choices"][0]["message"].get("content", "")) - ) - - model_response.created = int(time.time()) - model_response.model = model - usage = Usage( - prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens, - total_tokens=prompt_tokens + completion_tokens, - ) - setattr(model_response, "usage", usage) - return model_response - - -async def async_streaming( - optional_params, - encoding, - model_response: ModelResponse, - model: str, - model_id: Optional[str], - logging_obj: Any, - data, - aws_secret_access_key: Optional[str], - aws_access_key_id: Optional[str], - aws_region_name: Optional[str], -): - """ - Use aioboto3 - """ - import aioboto3 - - session = aioboto3.Session() - - if aws_access_key_id != None: - # uses auth params passed to completion - # aws_access_key_id is not None, assume user is trying to auth using litellm.completion - _client = session.client( - service_name="sagemaker-runtime", - aws_access_key_id=aws_access_key_id, - aws_secret_access_key=aws_secret_access_key, - region_name=aws_region_name, - ) - else: - # aws_access_key_id is None, assume user is trying to auth using env variables - # boto3 automaticaly reads env variables - - # we need to read region name from env - # I assume majority of users use .env for auth - region_name = ( - get_secret("AWS_REGION_NAME") - or aws_region_name # get region from config file if specified - or "us-west-2" # default to us-west-2 if region not specified - ) - _client = session.client( - service_name="sagemaker-runtime", - region_name=region_name, - ) - - async with _client as client: + ## Non-Streaming completion CALL try: if model_id is not None: - response = await client.invoke_endpoint_with_response_stream( - EndpointName=model, - InferenceComponentName=model_id, - ContentType="application/json", - Body=data, - CustomAttributes="accept_eula=true", + # Add model_id as InferenceComponentName header + # boto3 doc: https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_runtime_InvokeEndpoint.html + prepared_request.headers.update( + {"X-Amzn-SageMaker-Inference-Componen": model_id} ) - else: - response = await client.invoke_endpoint_with_response_stream( - EndpointName=model, - ContentType="application/json", - Body=data, - CustomAttributes="accept_eula=true", + + ## LOGGING + timeout = 300.0 + sync_handler = _get_httpx_client() + ## LOGGING + logging_obj.pre_call( + input=[], + api_key="", + additional_args={ + "complete_input_dict": _data, + "api_base": prepared_request.url, + "headers": prepared_request.headers, + }, + ) + + # make sync httpx post request here + try: + sync_response = sync_handler.post( + url=prepared_request.url, + headers=prepared_request.headers, + json=_data, + timeout=timeout, ) + + if sync_response.status_code != 200: + raise SagemakerError( + status_code=sync_response.status_code, + message=sync_response.text, + ) + except Exception as e: + ## LOGGING + logging_obj.post_call( + input=[], + api_key="", + original_response=str(e), + additional_args={"complete_input_dict": _data}, + ) + raise e except Exception as e: - raise SagemakerError(status_code=500, message=f"{str(e)}") - response = response["Body"] - async for chunk in response: - yield chunk + verbose_logger.error("Sagemaker error %s", str(e)) + status_code = ( + getattr(e, "response", {}) + .get("ResponseMetadata", {}) + .get("HTTPStatusCode", 500) + ) + error_message = ( + getattr(e, "response", {}).get("Error", 
{}).get("Message", str(e)) + ) + if "Inference Component Name header is required" in error_message: + error_message += "\n pass in via `litellm.completion(..., model_id={InferenceComponentName})`" + raise SagemakerError(status_code=status_code, message=error_message) - -async def async_completion( - optional_params, - encoding, - model_response: ModelResponse, - model: str, - logging_obj: Any, - data: dict, - model_id: Optional[str], - aws_secret_access_key: Optional[str], - aws_access_key_id: Optional[str], - aws_region_name: Optional[str], -): - """ - Use aioboto3 - """ - import aioboto3 - - session = aioboto3.Session() - - if aws_access_key_id != None: - # uses auth params passed to completion - # aws_access_key_id is not None, assume user is trying to auth using litellm.completion - _client = session.client( - service_name="sagemaker-runtime", - aws_access_key_id=aws_access_key_id, - aws_secret_access_key=aws_secret_access_key, - region_name=aws_region_name, + completion_response = sync_response.json() + ## LOGGING + logging_obj.post_call( + input=prompt, + api_key="", + original_response=completion_response, + additional_args={"complete_input_dict": _data}, ) - else: - # aws_access_key_id is None, assume user is trying to auth using env variables - # boto3 automaticaly reads env variables + print_verbose(f"raw model_response: {completion_response}") + ## RESPONSE OBJECT + try: + if isinstance(completion_response, list): + completion_response_choices = completion_response[0] + else: + completion_response_choices = completion_response + completion_output = "" + if "generation" in completion_response_choices: + completion_output += completion_response_choices["generation"] + elif "generated_text" in completion_response_choices: + completion_output += completion_response_choices["generated_text"] - # we need to read region name from env - # I assume majority of users use .env for auth - region_name = ( - get_secret("AWS_REGION_NAME") - or aws_region_name # get region from config file if specified - or "us-west-2" # default to us-west-2 if region not specified - ) - _client = session.client( - service_name="sagemaker-runtime", - region_name=region_name, + # check if the prompt template is part of output, if so - filter it out + if completion_output.startswith(prompt) and "" in prompt: + completion_output = completion_output.replace(prompt, "", 1) + + model_response.choices[0].message.content = completion_output # type: ignore + except: + raise SagemakerError( + message=f"LiteLLM Error: Unable to parse sagemaker RAW RESPONSE {json.dumps(completion_response)}", + status_code=500, + ) + + ## CALCULATING USAGE - baseten charges on time, not tokens - have some mapping of cost here. 
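+        # Usage is not read from the SageMaker response payload; it is
+        # approximated locally by running the provided `encoding` tokenizer
+        # over the prompt and the generated completion text below.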
+ prompt_tokens = len(encoding.encode(prompt)) + completion_tokens = len( + encoding.encode(model_response["choices"][0]["message"].get("content", "")) ) - async with _client as client: - encoded_data = json.dumps(data).encode("utf-8") + model_response.created = int(time.time()) + model_response.model = model + usage = Usage( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=prompt_tokens + completion_tokens, + ) + setattr(model_response, "usage", usage) + return model_response + + async def make_async_call( + self, + api_base: str, + headers: dict, + data: str, + logging_obj, + client=None, + ): + try: + if client is None: + client = ( + _get_async_httpx_client() + ) # Create a new client if none provided + response = await client.post( + api_base, + headers=headers, + json=data, + stream=True, + ) + + if response.status_code != 200: + raise SagemakerError( + status_code=response.status_code, message=response.text + ) + + decoder = AWSEventStreamDecoder(model="") + completion_stream = decoder.aiter_bytes( + response.aiter_bytes(chunk_size=1024) + ) + + return completion_stream + + # LOGGING + logging_obj.post_call( + input=[], + api_key="", + original_response="first stream response received", + additional_args={"complete_input_dict": data}, + ) + + except httpx.HTTPStatusError as err: + error_code = err.response.status_code + raise SagemakerError(status_code=error_code, message=err.response.text) + except httpx.TimeoutException as e: + raise SagemakerError(status_code=408, message="Timeout error occurred.") + except Exception as e: + raise SagemakerError(status_code=500, message=str(e)) + + async def async_streaming( + self, + prepared_request, + optional_params, + encoding, + model_response: ModelResponse, + model: str, + model_id: Optional[str], + logging_obj: Any, + data, + ): + streaming_response = CustomStreamWrapper( + completion_stream=None, + make_call=partial( + self.make_async_call, + api_base=prepared_request.url, + headers=prepared_request.headers, + data=data, + logging_obj=logging_obj, + ), + model=model, + custom_llm_provider="sagemaker", + logging_obj=logging_obj, + ) + + # LOGGING + logging_obj.post_call( + input=[], + api_key="", + original_response="first stream response received", + additional_args={"complete_input_dict": data}, + ) + + return streaming_response + + async def async_completion( + self, + prepared_request, + encoding, + model_response: ModelResponse, + model: str, + logging_obj: Any, + data: dict, + model_id: Optional[str], + ): + timeout = 300.0 + async_handler = _get_async_httpx_client() + ## LOGGING + logging_obj.pre_call( + input=[], + api_key="", + additional_args={ + "complete_input_dict": data, + "api_base": prepared_request.url, + "headers": prepared_request.headers, + }, + ) try: if model_id is not None: - ## LOGGING - request_str = f""" - response = client.invoke_endpoint( - EndpointName={model}, - InferenceComponentName={model_id}, - ContentType="application/json", - Body={data}, - CustomAttributes="accept_eula=true", + # Add model_id as InferenceComponentName header + # boto3 doc: https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_runtime_InvokeEndpoint.html + prepared_request.headers.update( + {"X-Amzn-SageMaker-Inference-Componen": model_id} ) - """ # type: ignore - logging_obj.pre_call( + # make async httpx post request here + try: + response = await async_handler.post( + url=prepared_request.url, + headers=prepared_request.headers, + json=data, + timeout=timeout, + ) + + if 
response.status_code != 200: + raise SagemakerError( + status_code=response.status_code, message=response.text + ) + except Exception as e: + ## LOGGING + logging_obj.post_call( input=data["inputs"], api_key="", - additional_args={ - "complete_input_dict": data, - "request_str": request_str, - }, - ) - response = await client.invoke_endpoint( - EndpointName=model, - InferenceComponentName=model_id, - ContentType="application/json", - Body=encoded_data, - CustomAttributes="accept_eula=true", - ) - else: - ## LOGGING - request_str = f""" - response = client.invoke_endpoint( - EndpointName={model}, - ContentType="application/json", - Body={data}, - CustomAttributes="accept_eula=true", - ) - """ # type: ignore - logging_obj.pre_call( - input=data["inputs"], - api_key="", - additional_args={ - "complete_input_dict": data, - "request_str": request_str, - }, - ) - response = await client.invoke_endpoint( - EndpointName=model, - ContentType="application/json", - Body=encoded_data, - CustomAttributes="accept_eula=true", + original_response=str(e), + additional_args={"complete_input_dict": data}, ) + raise e except Exception as e: error_message = f"{str(e)}" if "Inference Component Name header is required" in error_message: error_message += "\n pass in via `litellm.completion(..., model_id={InferenceComponentName})`" raise SagemakerError(status_code=500, message=error_message) - response = await response["Body"].read() - response = response.decode("utf8") + completion_response = response.json() ## LOGGING logging_obj.post_call( input=data["inputs"], @@ -586,7 +600,6 @@ async def async_completion( additional_args={"complete_input_dict": data}, ) ## RESPONSE OBJECT - completion_response = json.loads(response) try: if isinstance(completion_response, list): completion_response_choices = completion_response[0] @@ -625,141 +638,296 @@ async def async_completion( setattr(model_response, "usage", usage) return model_response + def embedding( + self, + model: str, + input: list, + model_response: EmbeddingResponse, + print_verbose: Callable, + encoding, + logging_obj, + custom_prompt_dict={}, + optional_params=None, + litellm_params=None, + logger_fn=None, + ): + """ + Supports Huggingface Jumpstart embeddings like GPT-6B + """ + ### BOTO3 INIT + import boto3 -def embedding( - model: str, - input: list, - model_response: EmbeddingResponse, - print_verbose: Callable, - encoding, - logging_obj, - custom_prompt_dict={}, - optional_params=None, - litellm_params=None, - logger_fn=None, -): - """ - Supports Huggingface Jumpstart embeddings like GPT-6B - """ - ### BOTO3 INIT - import boto3 + # pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them + aws_secret_access_key = optional_params.pop("aws_secret_access_key", None) + aws_access_key_id = optional_params.pop("aws_access_key_id", None) + aws_region_name = optional_params.pop("aws_region_name", None) - # pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them - aws_secret_access_key = optional_params.pop("aws_secret_access_key", None) - aws_access_key_id = optional_params.pop("aws_access_key_id", None) - aws_region_name = optional_params.pop("aws_region_name", None) + if aws_access_key_id is not None: + # uses auth params passed to completion + # aws_access_key_id is not None, assume user is trying to auth using litellm.completion + client = boto3.client( + service_name="sagemaker-runtime", + aws_access_key_id=aws_access_key_id, + 
aws_secret_access_key=aws_secret_access_key, + region_name=aws_region_name, + ) + else: + # aws_access_key_id is None, assume user is trying to auth using env variables + # boto3 automaticaly reads env variables - if aws_access_key_id is not None: - # uses auth params passed to completion - # aws_access_key_id is not None, assume user is trying to auth using litellm.completion - client = boto3.client( - service_name="sagemaker-runtime", - aws_access_key_id=aws_access_key_id, - aws_secret_access_key=aws_secret_access_key, - region_name=aws_region_name, - ) - else: - # aws_access_key_id is None, assume user is trying to auth using env variables - # boto3 automaticaly reads env variables + # we need to read region name from env + # I assume majority of users use .env for auth + region_name = ( + get_secret("AWS_REGION_NAME") + or aws_region_name # get region from config file if specified + or "us-west-2" # default to us-west-2 if region not specified + ) + client = boto3.client( + service_name="sagemaker-runtime", + region_name=region_name, + ) - # we need to read region name from env - # I assume majority of users use .env for auth - region_name = ( - get_secret("AWS_REGION_NAME") - or aws_region_name # get region from config file if specified - or "us-west-2" # default to us-west-2 if region not specified - ) - client = boto3.client( - service_name="sagemaker-runtime", - region_name=region_name, - ) + # pop streaming if it's in the optional params as 'stream' raises an error with sagemaker + inference_params = deepcopy(optional_params) + inference_params.pop("stream", None) - # pop streaming if it's in the optional params as 'stream' raises an error with sagemaker - inference_params = deepcopy(optional_params) - inference_params.pop("stream", None) + ## Load Config + config = litellm.SagemakerConfig.get_config() + for k, v in config.items(): + if ( + k not in inference_params + ): # completion(top_k=3) > sagemaker_config(top_k=3) <- allows for dynamic variables to be passed in + inference_params[k] = v - ## Load Config - config = litellm.SagemakerConfig.get_config() - for k, v in config.items(): - if ( - k not in inference_params - ): # completion(top_k=3) > sagemaker_config(top_k=3) <- allows for dynamic variables to be passed in - inference_params[k] = v + #### HF EMBEDDING LOGIC + data = json.dumps({"text_inputs": input}).encode("utf-8") - #### HF EMBEDDING LOGIC - data = json.dumps({"text_inputs": input}).encode("utf-8") - - ## LOGGING - request_str = f""" - response = client.invoke_endpoint( - EndpointName={model}, - ContentType="application/json", - Body={data}, # type: ignore - CustomAttributes="accept_eula=true", - )""" # type: ignore - logging_obj.pre_call( - input=input, - api_key="", - additional_args={"complete_input_dict": data, "request_str": request_str}, - ) - ## EMBEDDING CALL - try: + ## LOGGING + request_str = f""" response = client.invoke_endpoint( - EndpointName=model, + EndpointName={model}, ContentType="application/json", - Body=data, + Body={data}, # type: ignore CustomAttributes="accept_eula=true", + )""" # type: ignore + logging_obj.pre_call( + input=input, + api_key="", + additional_args={"complete_input_dict": data, "request_str": request_str}, ) - except Exception as e: - status_code = ( - getattr(e, "response", {}) - .get("ResponseMetadata", {}) - .get("HTTPStatusCode", 500) - ) - error_message = ( - getattr(e, "response", {}).get("Error", {}).get("Message", str(e)) - ) - raise SagemakerError(status_code=status_code, message=error_message) + ## EMBEDDING CALL + 
try: + response = client.invoke_endpoint( + EndpointName=model, + ContentType="application/json", + Body=data, + CustomAttributes="accept_eula=true", + ) + except Exception as e: + status_code = ( + getattr(e, "response", {}) + .get("ResponseMetadata", {}) + .get("HTTPStatusCode", 500) + ) + error_message = ( + getattr(e, "response", {}).get("Error", {}).get("Message", str(e)) + ) + raise SagemakerError(status_code=status_code, message=error_message) - response = json.loads(response["Body"].read().decode("utf8")) - ## LOGGING - logging_obj.post_call( - input=input, - api_key="", - original_response=response, - additional_args={"complete_input_dict": data}, - ) - - print_verbose(f"raw model_response: {response}") - if "embedding" not in response: - raise SagemakerError(status_code=500, message="embedding not found in response") - embeddings = response["embedding"] - - if not isinstance(embeddings, list): - raise SagemakerError( - status_code=422, message=f"Response not in expected format - {embeddings}" + response = json.loads(response["Body"].read().decode("utf8")) + ## LOGGING + logging_obj.post_call( + input=input, + api_key="", + original_response=response, + additional_args={"complete_input_dict": data}, ) - output_data = [] - for idx, embedding in enumerate(embeddings): - output_data.append( - {"object": "embedding", "index": idx, "embedding": embedding} + print_verbose(f"raw model_response: {response}") + if "embedding" not in response: + raise SagemakerError( + status_code=500, message="embedding not found in response" + ) + embeddings = response["embedding"] + + if not isinstance(embeddings, list): + raise SagemakerError( + status_code=422, + message=f"Response not in expected format - {embeddings}", + ) + + output_data = [] + for idx, embedding in enumerate(embeddings): + output_data.append( + {"object": "embedding", "index": idx, "embedding": embedding} + ) + + model_response.object = "list" + model_response.data = output_data + model_response.model = model + + input_tokens = 0 + for text in input: + input_tokens += len(encoding.encode(text)) + + setattr( + model_response, + "usage", + Usage( + prompt_tokens=input_tokens, + completion_tokens=0, + total_tokens=input_tokens, + ), ) - model_response.object = "list" - model_response.data = output_data - model_response.model = model + return model_response - input_tokens = 0 - for text in input: - input_tokens += len(encoding.encode(text)) - setattr( - model_response, - "usage", - Usage( - prompt_tokens=input_tokens, completion_tokens=0, total_tokens=input_tokens - ), - ) +def get_response_stream_shape(): + global _response_stream_shape_cache + if _response_stream_shape_cache is None: - return model_response + from botocore.loaders import Loader + from botocore.model import ServiceModel + + loader = Loader() + sagemaker_service_dict = loader.load_service_model( + "sagemaker-runtime", "service-2" + ) + sagemaker_service_model = ServiceModel(sagemaker_service_dict) + _response_stream_shape_cache = sagemaker_service_model.shape_for( + "InvokeEndpointWithResponseStreamOutput" + ) + return _response_stream_shape_cache + + +class AWSEventStreamDecoder: + def __init__(self, model: str) -> None: + from botocore.parsers import EventStreamJSONParser + + self.model = model + self.parser = EventStreamJSONParser() + self.content_blocks: List = [] + + def _chunk_parser(self, chunk_data: dict) -> GChunk: + verbose_logger.debug("in sagemaker chunk parser, chunk_data %s", chunk_data) + _token = chunk_data.get("token", {}) or {} + _index = 
chunk_data.get("index", None) or 0 + is_finished = False + finish_reason = "" + + _text = _token.get("text", "") + if _text == "<|endoftext|>": + return GChunk( + text="", + index=_index, + is_finished=True, + finish_reason="stop", + ) + + return GChunk( + text=_text, + index=_index, + is_finished=is_finished, + finish_reason=finish_reason, + ) + + def iter_bytes(self, iterator: Iterator[bytes]) -> Iterator[GChunk]: + """Given an iterator that yields lines, iterate over it & yield every event encountered""" + from botocore.eventstream import EventStreamBuffer + + event_stream_buffer = EventStreamBuffer() + accumulated_json = "" + + for chunk in iterator: + event_stream_buffer.add_data(chunk) + for event in event_stream_buffer: + message = self._parse_message_from_event(event) + if message: + # remove data: prefix and "\n\n" at the end + message = message.replace("data:", "").replace("\n\n", "") + + # Accumulate JSON data + accumulated_json += message + + # Try to parse the accumulated JSON + try: + _data = json.loads(accumulated_json) + yield self._chunk_parser(chunk_data=_data) + # Reset accumulated_json after successful parsing + accumulated_json = "" + except json.JSONDecodeError: + # If it's not valid JSON yet, continue to the next event + continue + + # Handle any remaining data after the iterator is exhausted + if accumulated_json: + try: + _data = json.loads(accumulated_json) + yield self._chunk_parser(chunk_data=_data) + except json.JSONDecodeError: + # Handle or log any unparseable data at the end + verbose_logger.error( + f"Warning: Unparseable JSON data remained: {accumulated_json}" + ) + + async def aiter_bytes( + self, iterator: AsyncIterator[bytes] + ) -> AsyncIterator[GChunk]: + """Given an async iterator that yields lines, iterate over it & yield every event encountered""" + from botocore.eventstream import EventStreamBuffer + + event_stream_buffer = EventStreamBuffer() + accumulated_json = "" + + async for chunk in iterator: + event_stream_buffer.add_data(chunk) + for event in event_stream_buffer: + message = self._parse_message_from_event(event) + if message: + verbose_logger.debug("sagemaker parsed chunk bytes %s", message) + # remove data: prefix and "\n\n" at the end + message = message.replace("data:", "").replace("\n\n", "") + + # Accumulate JSON data + accumulated_json += message + + # Try to parse the accumulated JSON + try: + _data = json.loads(accumulated_json) + yield self._chunk_parser(chunk_data=_data) + # Reset accumulated_json after successful parsing + accumulated_json = "" + except json.JSONDecodeError: + # If it's not valid JSON yet, continue to the next event + continue + + # Handle any remaining data after the iterator is exhausted + if accumulated_json: + try: + _data = json.loads(accumulated_json) + yield self._chunk_parser(chunk_data=_data) + except json.JSONDecodeError: + # Handle or log any unparseable data at the end + verbose_logger.error( + f"Warning: Unparseable JSON data remained: {accumulated_json}" + ) + + def _parse_message_from_event(self, event) -> Optional[str]: + response_dict = event.to_response_dict() + parsed_response = self.parser.parse(response_dict, get_response_stream_shape()) + + if response_dict["status_code"] != 200: + raise ValueError(f"Bad response code, expected 200: {response_dict}") + + if "chunk" in parsed_response: + chunk = parsed_response.get("chunk") + if not chunk: + return None + return chunk.get("bytes").decode() # type: ignore[no-any-return] + else: + chunk = response_dict.get("body") + if not chunk: + return 
None + + return chunk.decode() # type: ignore[no-any-return] diff --git a/litellm/llms/triton.py b/litellm/llms/triton.py index 7d0338d069..14a2e828b4 100644 --- a/litellm/llms/triton.py +++ b/litellm/llms/triton.py @@ -240,10 +240,10 @@ class TritonChatCompletion(BaseLLM): handler = HTTPHandler() if stream: return self._handle_stream( - handler, api_base, data_for_triton, model, logging_obj + handler, api_base, json_data_for_triton, model, logging_obj ) else: - response = handler.post(url=api_base, data=data_for_triton, headers=headers) + response = handler.post(url=api_base, data=json_data_for_triton, headers=headers) return self._handle_response( response, model_response, logging_obj, type_of_model=type_of_model ) diff --git a/litellm/main.py b/litellm/main.py index d7a3ca996d..cf7a4a5e7e 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -95,7 +95,6 @@ from .llms import ( palm, petals, replicate, - sagemaker, together_ai, triton, vertex_ai, @@ -120,6 +119,7 @@ from .llms.prompt_templates.factory import ( prompt_factory, stringify_json_tool_call_content, ) +from .llms.sagemaker import SagemakerLLM from .llms.text_completion_codestral import CodestralTextCompletion from .llms.triton import TritonChatCompletion from .llms.vertex_ai_partner import VertexAIPartnerModels @@ -166,6 +166,7 @@ bedrock_converse_chat_completion = BedrockConverseLLM() vertex_chat_completion = VertexLLM() vertex_partner_models_chat_completion = VertexAIPartnerModels() watsonxai = IBMWatsonXAI() +sagemaker_llm = SagemakerLLM() ####### COMPLETION ENDPOINTS ################ @@ -2216,7 +2217,7 @@ def completion( response = model_response elif custom_llm_provider == "sagemaker": # boto3 reads keys from .env - model_response = sagemaker.completion( + model_response = sagemaker_llm.completion( model=model, messages=messages, model_response=model_response, @@ -2230,26 +2231,13 @@ def completion( logging_obj=logging, acompletion=acompletion, ) - if ( - "stream" in optional_params and optional_params["stream"] == True - ): ## [BETA] - print_verbose(f"ENTERS SAGEMAKER CUSTOMSTREAMWRAPPER") - from .llms.sagemaker import TokenIterator - - tokenIterator = TokenIterator(model_response, acompletion=acompletion) - response = CustomStreamWrapper( - completion_stream=tokenIterator, - model=model, - custom_llm_provider="sagemaker", - logging_obj=logging, - ) + if optional_params.get("stream", False): ## LOGGING logging.post_call( input=messages, api_key=None, - original_response=response, + original_response=model_response, ) - return response ## RESPONSE OBJECT response = model_response @@ -3529,7 +3517,7 @@ def embedding( model_response=EmbeddingResponse(), ) elif custom_llm_provider == "sagemaker": - response = sagemaker.embedding( + response = sagemaker_llm.embedding( model=model, input=input, encoding=encoding, @@ -4898,7 +4886,6 @@ async def ahealth_check( verbose_logger.error( "litellm.ahealth_check(): Exception occured - {}".format(str(e)) ) - verbose_logger.debug(traceback.format_exc()) stack_trace = traceback.format_exc() if isinstance(stack_trace, str): stack_trace = stack_trace[:1000] @@ -4907,7 +4894,12 @@ async def ahealth_check( "error": "Missing `mode`. 
Set the `mode` for the model - https://docs.litellm.ai/docs/proxy/health#embedding-models" } - error_to_return = str(e) + " stack trace: " + stack_trace + error_to_return = ( + str(e) + + "\nHave you set 'mode' - https://docs.litellm.ai/docs/proxy/health#embedding-models" + + "\nstack trace: " + + stack_trace + ) return {"error": error_to_return} diff --git a/litellm/model_prices_and_context_window_backup.json b/litellm/model_prices_and_context_window_backup.json index 455fe1e3c5..d30270c5c8 100644 --- a/litellm/model_prices_and_context_window_backup.json +++ b/litellm/model_prices_and_context_window_backup.json @@ -57,6 +57,18 @@ "supports_parallel_function_calling": true, "supports_vision": true }, + "chatgpt-4o-latest": { + "max_tokens": 4096, + "max_input_tokens": 128000, + "max_output_tokens": 4096, + "input_cost_per_token": 0.000005, + "output_cost_per_token": 0.000015, + "litellm_provider": "openai", + "mode": "chat", + "supports_function_calling": true, + "supports_parallel_function_calling": true, + "supports_vision": true + }, "gpt-4o-2024-05-13": { "max_tokens": 4096, "max_input_tokens": 128000, @@ -2062,7 +2074,8 @@ "litellm_provider": "vertex_ai-anthropic_models", "mode": "chat", "supports_function_calling": true, - "supports_vision": true + "supports_vision": true, + "supports_assistant_prefill": true }, "vertex_ai/claude-3-5-sonnet@20240620": { "max_tokens": 4096, @@ -2073,7 +2086,8 @@ "litellm_provider": "vertex_ai-anthropic_models", "mode": "chat", "supports_function_calling": true, - "supports_vision": true + "supports_vision": true, + "supports_assistant_prefill": true }, "vertex_ai/claude-3-haiku@20240307": { "max_tokens": 4096, @@ -2084,7 +2098,8 @@ "litellm_provider": "vertex_ai-anthropic_models", "mode": "chat", "supports_function_calling": true, - "supports_vision": true + "supports_vision": true, + "supports_assistant_prefill": true }, "vertex_ai/claude-3-opus@20240229": { "max_tokens": 4096, @@ -2095,7 +2110,8 @@ "litellm_provider": "vertex_ai-anthropic_models", "mode": "chat", "supports_function_calling": true, - "supports_vision": true + "supports_vision": true, + "supports_assistant_prefill": true }, "vertex_ai/meta/llama3-405b-instruct-maas": { "max_tokens": 32000, @@ -4519,6 +4535,69 @@ "litellm_provider": "perplexity", "mode": "chat" }, + "perplexity/llama-3.1-70b-instruct": { + "max_tokens": 131072, + "max_input_tokens": 131072, + "max_output_tokens": 131072, + "input_cost_per_token": 0.000001, + "output_cost_per_token": 0.000001, + "litellm_provider": "perplexity", + "mode": "chat" + }, + "perplexity/llama-3.1-8b-instruct": { + "max_tokens": 131072, + "max_input_tokens": 131072, + "max_output_tokens": 131072, + "input_cost_per_token": 0.0000002, + "output_cost_per_token": 0.0000002, + "litellm_provider": "perplexity", + "mode": "chat" + }, + "perplexity/llama-3.1-sonar-huge-128k-online": { + "max_tokens": 127072, + "max_input_tokens": 127072, + "max_output_tokens": 127072, + "input_cost_per_token": 0.000005, + "output_cost_per_token": 0.000005, + "litellm_provider": "perplexity", + "mode": "chat" + }, + "perplexity/llama-3.1-sonar-large-128k-online": { + "max_tokens": 127072, + "max_input_tokens": 127072, + "max_output_tokens": 127072, + "input_cost_per_token": 0.000001, + "output_cost_per_token": 0.000001, + "litellm_provider": "perplexity", + "mode": "chat" + }, + "perplexity/llama-3.1-sonar-large-128k-chat": { + "max_tokens": 131072, + "max_input_tokens": 131072, + "max_output_tokens": 131072, + "input_cost_per_token": 0.000001, + 
"output_cost_per_token": 0.000001, + "litellm_provider": "perplexity", + "mode": "chat" + }, + "perplexity/llama-3.1-sonar-small-128k-chat": { + "max_tokens": 131072, + "max_input_tokens": 131072, + "max_output_tokens": 131072, + "input_cost_per_token": 0.0000002, + "output_cost_per_token": 0.0000002, + "litellm_provider": "perplexity", + "mode": "chat" + }, + "perplexity/llama-3.1-sonar-small-128k-online": { + "max_tokens": 127072, + "max_input_tokens": 127072, + "max_output_tokens": 127072, + "input_cost_per_token": 0.0000002, + "output_cost_per_token": 0.0000002, + "litellm_provider": "perplexity", + "mode": "chat" + }, "perplexity/pplx-7b-chat": { "max_tokens": 8192, "max_input_tokens": 8192, diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml index e4e180727d..dfa5c16520 100644 --- a/litellm/proxy/_new_secret_config.yaml +++ b/litellm/proxy/_new_secret_config.yaml @@ -1,13 +1,6 @@ model_list: - - model_name: "*" + - model_name: "gpt-4" litellm_params: - model: "*" - -# general_settings: -# master_key: sk-1234 -# pass_through_endpoints: -# - path: "/api/public/ingestion" # route you want to add to LiteLLM Proxy Server -# target: "https://us.cloud.langfuse.com/api/public/ingestion" # URL this route should forward -# headers: -# LANGFUSE_PUBLIC_KEY: "os.environ/LANGFUSE_PUBLIC_KEY" # your langfuse account public key -# LANGFUSE_SECRET_KEY: "os.environ/LANGFUSE_SECRET_KEY" # your langfuse account secret key \ No newline at end of file + model: "gpt-4" + model_info: + my_custom_key: "my_custom_value" \ No newline at end of file diff --git a/litellm/proxy/auth/user_api_key_auth.py b/litellm/proxy/auth/user_api_key_auth.py index 5ae149f1bd..00e78f64e6 100644 --- a/litellm/proxy/auth/user_api_key_auth.py +++ b/litellm/proxy/auth/user_api_key_auth.py @@ -12,7 +12,7 @@ import json import secrets import traceback from datetime import datetime, timedelta, timezone -from typing import Optional +from typing import Optional, Tuple from uuid import uuid4 import fastapi @@ -125,7 +125,7 @@ async def user_api_key_auth( # Check 2. 
FILTER IP ADDRESS await check_if_request_size_is_safe(request=request) - is_valid_ip = _check_valid_ip( + is_valid_ip, passed_in_ip = _check_valid_ip( allowed_ips=general_settings.get("allowed_ips", None), use_x_forwarded_for=general_settings.get("use_x_forwarded_for", False), request=request, @@ -134,7 +134,7 @@ async def user_api_key_auth( if not is_valid_ip: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, - detail="Access forbidden: IP address not allowed.", + detail=f"Access forbidden: IP address {passed_in_ip} not allowed.", ) pass_through_endpoints: Optional[List[dict]] = general_settings.get( @@ -1251,12 +1251,12 @@ def _check_valid_ip( allowed_ips: Optional[List[str]], request: Request, use_x_forwarded_for: Optional[bool] = False, -) -> bool: +) -> Tuple[bool, Optional[str]]: """ Returns if ip is allowed or not """ if allowed_ips is None: # if not set, assume true - return True + return True, None # if general_settings.get("use_x_forwarded_for") is True then use x-forwarded-for client_ip = None @@ -1267,9 +1267,9 @@ def _check_valid_ip( # Check if IP address is allowed if client_ip not in allowed_ips: - return False + return False, client_ip - return True + return True, client_ip def get_api_key_from_custom_header( diff --git a/litellm/proxy/common_utils/load_config_utils.py b/litellm/proxy/common_utils/load_config_utils.py new file mode 100644 index 0000000000..b009695b8c --- /dev/null +++ b/litellm/proxy/common_utils/load_config_utils.py @@ -0,0 +1,56 @@ +import yaml + +from litellm._logging import verbose_proxy_logger + + +def get_file_contents_from_s3(bucket_name, object_key): + try: + # v0 rely on boto3 for authentication - allowing boto3 to handle IAM credentials etc + import tempfile + + import boto3 + from botocore.config import Config + from botocore.credentials import Credentials + + from litellm.main import bedrock_converse_chat_completion + + credentials: Credentials = bedrock_converse_chat_completion.get_credentials() + s3_client = boto3.client( + "s3", + aws_access_key_id=credentials.access_key, + aws_secret_access_key=credentials.secret_key, + aws_session_token=credentials.token, # Optional, if using temporary credentials + ) + verbose_proxy_logger.debug( + f"Retrieving {object_key} from S3 bucket: {bucket_name}" + ) + response = s3_client.get_object(Bucket=bucket_name, Key=object_key) + verbose_proxy_logger.debug(f"Response: {response}") + + # Read the file contents + file_contents = response["Body"].read().decode("utf-8") + verbose_proxy_logger.debug(f"File contents retrieved from S3") + + # Create a temporary file with YAML extension + with tempfile.NamedTemporaryFile(delete=False, suffix=".yaml") as temp_file: + temp_file.write(file_contents.encode("utf-8")) + temp_file_path = temp_file.name + verbose_proxy_logger.debug(f"File stored temporarily at: {temp_file_path}") + + # Load the YAML file content + with open(temp_file_path, "r") as yaml_file: + config = yaml.safe_load(yaml_file) + + return config + except ImportError: + # this is most likely if a user is not using the litellm docker container + verbose_proxy_logger.error(f"ImportError: {str(e)}") + pass + except Exception as e: + verbose_proxy_logger.error(f"Error retrieving file contents: {str(e)}") + return None + + +# # Example usage +# bucket_name = 'litellm-proxy' +# object_key = 'litellm_proxy_config.yaml' diff --git a/litellm/proxy/litellm_pre_call_utils.py b/litellm/proxy/litellm_pre_call_utils.py index 990cb52337..dd39efd6b7 100644 --- a/litellm/proxy/litellm_pre_call_utils.py +++ 
@@ -5,7 +5,12 @@ from fastapi import Request

 import litellm
 from litellm._logging import verbose_logger, verbose_proxy_logger
-from litellm.proxy._types import CommonProxyErrors, TeamCallbackMetadata, UserAPIKeyAuth
+from litellm.proxy._types import (
+    AddTeamCallback,
+    CommonProxyErrors,
+    TeamCallbackMetadata,
+    UserAPIKeyAuth,
+)
 from litellm.types.utils import SupportedCacheControls

 if TYPE_CHECKING:
@@ -59,6 +64,42 @@ def safe_add_api_version_from_query_params(data: dict, request: Request):
         verbose_logger.error("error checking api version in query params: %s", str(e))


+def convert_key_logging_metadata_to_callback(
+    data: AddTeamCallback, team_callback_settings_obj: Optional[TeamCallbackMetadata]
+) -> TeamCallbackMetadata:
+    if team_callback_settings_obj is None:
+        team_callback_settings_obj = TeamCallbackMetadata()
+    if data.callback_type == "success":
+        if team_callback_settings_obj.success_callback is None:
+            team_callback_settings_obj.success_callback = []
+
+        if data.callback_name not in team_callback_settings_obj.success_callback:
+            team_callback_settings_obj.success_callback.append(data.callback_name)
+    elif data.callback_type == "failure":
+        if team_callback_settings_obj.failure_callback is None:
+            team_callback_settings_obj.failure_callback = []
+
+        if data.callback_name not in team_callback_settings_obj.failure_callback:
+            team_callback_settings_obj.failure_callback.append(data.callback_name)
+    elif data.callback_type == "success_and_failure":
+        if team_callback_settings_obj.success_callback is None:
+            team_callback_settings_obj.success_callback = []
+        if team_callback_settings_obj.failure_callback is None:
+            team_callback_settings_obj.failure_callback = []
+        if data.callback_name not in team_callback_settings_obj.success_callback:
+            team_callback_settings_obj.success_callback.append(data.callback_name)
+
+        if data.callback_name not in team_callback_settings_obj.failure_callback:
+            team_callback_settings_obj.failure_callback.append(data.callback_name)
+
+    for var, value in data.callback_vars.items():
+        if team_callback_settings_obj.callback_vars is None:
+            team_callback_settings_obj.callback_vars = {}
+        team_callback_settings_obj.callback_vars[var] = litellm.get_secret(value)
+
+    return team_callback_settings_obj
+
+
 async def add_litellm_data_to_request(
     data: dict,
     request: Request,
@@ -85,14 +126,19 @@ async def add_litellm_data_to_request(

     safe_add_api_version_from_query_params(data, request)

+    _headers = dict(request.headers)
+
     # Include original request and headers in the data
     data["proxy_server_request"] = {
         "url": str(request.url),
         "method": request.method,
-        "headers": dict(request.headers),
+        "headers": _headers,
         "body": copy.copy(data),  # use copy instead of deepcopy
     }

+    ## Forward any LLM API Provider specific headers in extra_headers
+    add_provider_specific_headers_to_request(data=data, headers=_headers)
+
     ## Cache Controls
     headers = request.headers
     verbose_proxy_logger.debug("Request Headers: %s", headers)
@@ -224,6 +270,7 @@ async def add_litellm_data_to_request(
         }  # add the team-specific configs to the completion call

     # Team Callbacks controls
+    callback_settings_obj: Optional[TeamCallbackMetadata] = None
     if user_api_key_dict.team_metadata is not None:
         team_metadata = user_api_key_dict.team_metadata
         if "callback_settings" in team_metadata:
@@ -241,17 +288,54 @@ async def add_litellm_data_to_request(
                     }
                 }
                 """
-            data["success_callback"] = callback_settings_obj.success_callback
-            data["failure_callback"] = 
callback_settings_obj.failure_callback + elif ( + user_api_key_dict.metadata is not None + and "logging" in user_api_key_dict.metadata + ): + for item in user_api_key_dict.metadata["logging"]: - if callback_settings_obj.callback_vars is not None: - # unpack callback_vars in data - for k, v in callback_settings_obj.callback_vars.items(): - data[k] = v + callback_settings_obj = convert_key_logging_metadata_to_callback( + data=AddTeamCallback(**item), + team_callback_settings_obj=callback_settings_obj, + ) + + if callback_settings_obj is not None: + data["success_callback"] = callback_settings_obj.success_callback + data["failure_callback"] = callback_settings_obj.failure_callback + + if callback_settings_obj.callback_vars is not None: + # unpack callback_vars in data + for k, v in callback_settings_obj.callback_vars.items(): + data[k] = v return data +def add_provider_specific_headers_to_request( + data: dict, + headers: dict, +): + ANTHROPIC_API_HEADERS = [ + "anthropic-version", + "anthropic-beta", + ] + + extra_headers = data.get("extra_headers", {}) or {} + + # boolean to indicate if a header was added + added_header = False + for header in ANTHROPIC_API_HEADERS: + if header in headers: + header_value = headers[header] + extra_headers.update({header: header_value}) + added_header = True + + if added_header is True: + data["extra_headers"] = extra_headers + + return + + def _add_otel_traceparent_to_data(data: dict, request: Request): from litellm.proxy.proxy_server import open_telemetry_logger diff --git a/litellm/proxy/proxy_config.yaml b/litellm/proxy/proxy_config.yaml index 660c27f249..d25f1b9468 100644 --- a/litellm/proxy/proxy_config.yaml +++ b/litellm/proxy/proxy_config.yaml @@ -19,6 +19,9 @@ model_list: litellm_params: model: mistral/mistral-small-latest api_key: "os.environ/MISTRAL_API_KEY" + - model_name: bedrock-anthropic + litellm_params: + model: bedrock/anthropic.claude-3-sonnet-20240229-v1:0 - model_name: gemini-1.5-pro-001 litellm_params: model: vertex_ai_beta/gemini-1.5-pro-001 @@ -39,7 +42,7 @@ general_settings: litellm_settings: fallbacks: [{"gemini-1.5-pro-001": ["gpt-4o"]}] - success_callback: ["langfuse", "prometheus"] - langfuse_default_tags: ["cache_hit", "cache_key", "proxy_base_url", "user_api_key_alias", "user_api_key_user_id", "user_api_key_user_email", "user_api_key_team_alias", "semantic-similarity", "proxy_base_url"] - failure_callback: ["prometheus"] + callbacks: ["gcs_bucket"] + success_callback: ["langfuse"] + langfuse_default_tags: ["cache_hit", "cache_key", "user_api_key_alias", "user_api_key_team_alias"] cache: True diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 4d141955b2..10c06b2ece 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -159,6 +159,7 @@ from litellm.proxy.common_utils.http_parsing_utils import ( check_file_size_under_limit, ) from litellm.proxy.common_utils.init_callbacks import initialize_callbacks_on_proxy +from litellm.proxy.common_utils.load_config_utils import get_file_contents_from_s3 from litellm.proxy.common_utils.openai_endpoint_utils import ( remove_sensitive_info_from_deployment, ) @@ -197,6 +198,8 @@ from litellm.proxy.pass_through_endpoints.pass_through_endpoints import ( from litellm.proxy.pass_through_endpoints.pass_through_endpoints import ( router as pass_through_router, ) +from litellm.proxy.route_llm_request import route_request + from litellm.proxy.secret_managers.aws_secret_manager import ( load_aws_kms, load_aws_secret_manager, @@ -1444,7 +1447,18 @@ 
class ProxyConfig: global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, user_custom_key_generate, use_background_health_checks, health_check_interval, use_queue, custom_db_client, proxy_budget_rescheduler_max_time, proxy_budget_rescheduler_min_time, ui_access_mode, litellm_master_key_hash, proxy_batch_write_at, disable_spend_logs, prompt_injection_detection_obj, redis_usage_cache, store_model_in_db, premium_user, open_telemetry_logger, health_check_details # Load existing config - config = await self.get_config(config_file_path=config_file_path) + if os.environ.get("LITELLM_CONFIG_BUCKET_NAME") is not None: + bucket_name = os.environ.get("LITELLM_CONFIG_BUCKET_NAME") + object_key = os.environ.get("LITELLM_CONFIG_BUCKET_OBJECT_KEY") + verbose_proxy_logger.debug( + "bucket_name: %s, object_key: %s", bucket_name, object_key + ) + config = get_file_contents_from_s3( + bucket_name=bucket_name, object_key=object_key + ) + else: + # default to file + config = await self.get_config(config_file_path=config_file_path) ## PRINT YAML FOR CONFIRMING IT WORKS printed_yaml = copy.deepcopy(config) printed_yaml.pop("environment_variables", None) @@ -2652,6 +2666,15 @@ async def startup_event(): ) else: await initialize(**worker_config) + elif os.environ.get("LITELLM_CONFIG_BUCKET_NAME") is not None: + ( + llm_router, + llm_model_list, + general_settings, + ) = await proxy_config.load_config( + router=llm_router, config_file_path=worker_config + ) + else: # if not, assume it's a json string worker_config = json.loads(os.getenv("WORKER_CONFIG")) @@ -3036,68 +3059,13 @@ async def chat_completion( ### ROUTE THE REQUEST ### # Do not change this - it should be a constant time fetch - ALWAYS - router_model_names = llm_router.model_names if llm_router is not None else [] - # skip router if user passed their key - if "api_key" in data: - tasks.append(litellm.acompletion(**data)) - elif "," in data["model"] and llm_router is not None: - if ( - data.get("fastest_response", None) is not None - and data["fastest_response"] == True - ): - tasks.append(llm_router.abatch_completion_fastest_response(**data)) - else: - _models_csv_string = data.pop("model") - _models = [model.strip() for model in _models_csv_string.split(",")] - tasks.append(llm_router.abatch_completion(models=_models, **data)) - elif "user_config" in data: - # initialize a new router instance. 
make request using this Router - router_config = data.pop("user_config") - user_router = litellm.Router(**router_config) - tasks.append(user_router.acompletion(**data)) - elif ( - llm_router is not None and data["model"] in router_model_names - ): # model in router model list - tasks.append(llm_router.acompletion(**data)) - elif ( - llm_router is not None and data["model"] in llm_router.get_model_ids() - ): # model in router model list - tasks.append(llm_router.acompletion(**data)) - elif ( - llm_router is not None - and llm_router.model_group_alias is not None - and data["model"] in llm_router.model_group_alias - ): # model set in model_group_alias - tasks.append(llm_router.acompletion(**data)) - elif ( - llm_router is not None and data["model"] in llm_router.deployment_names - ): # model in router deployments, calling a specific deployment on the router - tasks.append(llm_router.acompletion(**data, specific_deployment=True)) - elif ( - llm_router is not None - and data["model"] not in router_model_names - and llm_router.router_general_settings.pass_through_all_models is True - ): - tasks.append(litellm.acompletion(**data)) - elif ( - llm_router is not None - and data["model"] not in router_model_names - and ( - llm_router.default_deployment is not None - or len(llm_router.provider_default_deployments) > 0 - ) - ): # model in router deployments, calling a specific deployment on the router - tasks.append(llm_router.acompletion(**data)) - elif user_model is not None: # `litellm --model ` - tasks.append(litellm.acompletion(**data)) - else: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail={ - "error": "chat_completion: Invalid model name passed in model=" - + data.get("model", "") - }, - ) + llm_call = await route_request( + data=data, + route_type="acompletion", + llm_router=llm_router, + user_model=user_model, + ) + tasks.append(llm_call) # wait for call to end llm_responses = asyncio.gather( @@ -3320,58 +3288,15 @@ async def completion( ) ### ROUTE THE REQUESTs ### - router_model_names = llm_router.model_names if llm_router is not None else [] - # skip router if user passed their key - if "api_key" in data: - llm_response = asyncio.create_task(litellm.atext_completion(**data)) - elif ( - llm_router is not None and data["model"] in router_model_names - ): # model in router model list - llm_response = asyncio.create_task(llm_router.atext_completion(**data)) - elif ( - llm_router is not None - and llm_router.model_group_alias is not None - and data["model"] in llm_router.model_group_alias - ): # model set in model_group_alias - llm_response = asyncio.create_task(llm_router.atext_completion(**data)) - elif ( - llm_router is not None and data["model"] in llm_router.deployment_names - ): # model in router deployments, calling a specific deployment on the router - llm_response = asyncio.create_task( - llm_router.atext_completion(**data, specific_deployment=True) - ) - elif ( - llm_router is not None and data["model"] in llm_router.get_model_ids() - ): # model in router model list - llm_response = asyncio.create_task(llm_router.atext_completion(**data)) - elif ( - llm_router is not None - and data["model"] not in router_model_names - and llm_router.router_general_settings.pass_through_all_models is True - ): - llm_response = asyncio.create_task(litellm.atext_completion(**data)) - elif ( - llm_router is not None - and data["model"] not in router_model_names - and ( - llm_router.default_deployment is not None - or len(llm_router.provider_default_deployments) > 0 - ) - ): # 
model in router deployments, calling a specific deployment on the router - llm_response = asyncio.create_task(llm_router.atext_completion(**data)) - elif user_model is not None: # `litellm --model ` - llm_response = asyncio.create_task(litellm.atext_completion(**data)) - else: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail={ - "error": "completion: Invalid model name passed in model=" - + data.get("model", "") - }, - ) + llm_call = await route_request( + data=data, + route_type="atext_completion", + llm_router=llm_router, + user_model=user_model, + ) # Await the llm_response task - response = await llm_response + response = await llm_call hidden_params = getattr(response, "_hidden_params", {}) or {} model_id = hidden_params.get("model_id", None) or "" @@ -3585,59 +3510,13 @@ async def embeddings( ) ## ROUTE TO CORRECT ENDPOINT ## - # skip router if user passed their key - if "api_key" in data: - tasks.append(litellm.aembedding(**data)) - elif "user_config" in data: - # initialize a new router instance. make request using this Router - router_config = data.pop("user_config") - user_router = litellm.Router(**router_config) - tasks.append(user_router.aembedding(**data)) - elif ( - llm_router is not None and data["model"] in router_model_names - ): # model in router model list - tasks.append(llm_router.aembedding(**data)) - elif ( - llm_router is not None - and llm_router.model_group_alias is not None - and data["model"] in llm_router.model_group_alias - ): # model set in model_group_alias - tasks.append( - llm_router.aembedding(**data) - ) # ensure this goes the llm_router, router will do the correct alias mapping - elif ( - llm_router is not None and data["model"] in llm_router.deployment_names - ): # model in router deployments, calling a specific deployment on the router - tasks.append(llm_router.aembedding(**data, specific_deployment=True)) - elif ( - llm_router is not None and data["model"] in llm_router.get_model_ids() - ): # model in router deployments, calling a specific deployment on the router - tasks.append(llm_router.aembedding(**data)) - elif ( - llm_router is not None - and data["model"] not in router_model_names - and llm_router.router_general_settings.pass_through_all_models is True - ): - tasks.append(litellm.aembedding(**data)) - elif ( - llm_router is not None - and data["model"] not in router_model_names - and ( - llm_router.default_deployment is not None - or len(llm_router.provider_default_deployments) > 0 - ) - ): # model in router deployments, calling a specific deployment on the router - tasks.append(llm_router.aembedding(**data)) - elif user_model is not None: # `litellm --model ` - tasks.append(litellm.aembedding(**data)) - else: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail={ - "error": "embeddings: Invalid model name passed in model=" - + data.get("model", "") - }, - ) + llm_call = await route_request( + data=data, + route_type="aembedding", + llm_router=llm_router, + user_model=user_model, + ) + tasks.append(llm_call) # wait for call to end llm_responses = asyncio.gather( @@ -3768,46 +3647,13 @@ async def image_generation( ) ## ROUTE TO CORRECT ENDPOINT ## - # skip router if user passed their key - if "api_key" in data: - response = await litellm.aimage_generation(**data) - elif ( - llm_router is not None and data["model"] in router_model_names - ): # model in router model list - response = await llm_router.aimage_generation(**data) - elif ( - llm_router is not None and data["model"] in 
llm_router.deployment_names - ): # model in router deployments, calling a specific deployment on the router - response = await llm_router.aimage_generation( - **data, specific_deployment=True - ) - elif ( - llm_router is not None - and llm_router.model_group_alias is not None - and data["model"] in llm_router.model_group_alias - ): # model set in model_group_alias - response = await llm_router.aimage_generation( - **data - ) # ensure this goes the llm_router, router will do the correct alias mapping - elif ( - llm_router is not None - and data["model"] not in router_model_names - and ( - llm_router.default_deployment is not None - or len(llm_router.provider_default_deployments) > 0 - ) - ): # model in router deployments, calling a specific deployment on the router - response = await llm_router.aimage_generation(**data) - elif user_model is not None: # `litellm --model ` - response = await litellm.aimage_generation(**data) - else: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail={ - "error": "image_generation: Invalid model name passed in model=" - + data.get("model", "") - }, - ) + llm_call = await route_request( + data=data, + route_type="aimage_generation", + llm_router=llm_router, + user_model=user_model, + ) + response = await llm_call ### ALERTING ### asyncio.create_task( @@ -3915,44 +3761,13 @@ async def audio_speech( ) ## ROUTE TO CORRECT ENDPOINT ## - # skip router if user passed their key - if "api_key" in data: - response = await litellm.aspeech(**data) - elif ( - llm_router is not None and data["model"] in router_model_names - ): # model in router model list - response = await llm_router.aspeech(**data) - elif ( - llm_router is not None and data["model"] in llm_router.deployment_names - ): # model in router deployments, calling a specific deployment on the router - response = await llm_router.aspeech(**data, specific_deployment=True) - elif ( - llm_router is not None - and llm_router.model_group_alias is not None - and data["model"] in llm_router.model_group_alias - ): # model set in model_group_alias - response = await llm_router.aspeech( - **data - ) # ensure this goes the llm_router, router will do the correct alias mapping - elif ( - llm_router is not None - and data["model"] not in router_model_names - and ( - llm_router.default_deployment is not None - or len(llm_router.provider_default_deployments) > 0 - ) - ): # model in router deployments, calling a specific deployment on the router - response = await llm_router.aspeech(**data) - elif user_model is not None: # `litellm --model ` - response = await litellm.aspeech(**data) - else: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail={ - "error": "audio_speech: Invalid model name passed in model=" - + data.get("model", "") - }, - ) + llm_call = await route_request( + data=data, + route_type="aspeech", + llm_router=llm_router, + user_model=user_model, + ) + response = await llm_call ### ALERTING ### asyncio.create_task( @@ -4085,47 +3900,13 @@ async def audio_transcriptions( ) ## ROUTE TO CORRECT ENDPOINT ## - # skip router if user passed their key - if "api_key" in data: - response = await litellm.atranscription(**data) - elif ( - llm_router is not None and data["model"] in router_model_names - ): # model in router model list - response = await llm_router.atranscription(**data) - - elif ( - llm_router is not None and data["model"] in llm_router.deployment_names - ): # model in router deployments, calling a specific deployment on the router - response = await 
llm_router.atranscription( - **data, specific_deployment=True - ) - elif ( - llm_router is not None - and llm_router.model_group_alias is not None - and data["model"] in llm_router.model_group_alias - ): # model set in model_group_alias - response = await llm_router.atranscription( - **data - ) # ensure this goes the llm_router, router will do the correct alias mapping - elif ( - llm_router is not None - and data["model"] not in router_model_names - and ( - llm_router.default_deployment is not None - or len(llm_router.provider_default_deployments) > 0 - ) - ): # model in router deployments, calling a specific deployment on the router - response = await llm_router.atranscription(**data) - elif user_model is not None: # `litellm --model ` - response = await litellm.atranscription(**data) - else: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail={ - "error": "audio_transcriptions: Invalid model name passed in model=" - + data.get("model", "") - }, - ) + llm_call = await route_request( + data=data, + route_type="atranscription", + llm_router=llm_router, + user_model=user_model, + ) + response = await llm_call except Exception as e: raise HTTPException(status_code=500, detail=str(e)) finally: @@ -5341,40 +5122,13 @@ async def moderations( start_time = time.time() ## ROUTE TO CORRECT ENDPOINT ## - # skip router if user passed their key - if "api_key" in data: - response = await litellm.amoderation(**data) - elif ( - llm_router is not None and data.get("model") in router_model_names - ): # model in router model list - response = await llm_router.amoderation(**data) - elif ( - llm_router is not None and data.get("model") in llm_router.deployment_names - ): # model in router deployments, calling a specific deployment on the router - response = await llm_router.amoderation(**data, specific_deployment=True) - elif ( - llm_router is not None - and llm_router.model_group_alias is not None - and data.get("model") in llm_router.model_group_alias - ): # model set in model_group_alias - response = await llm_router.amoderation( - **data - ) # ensure this goes the llm_router, router will do the correct alias mapping - elif ( - llm_router is not None - and data.get("model") not in router_model_names - and ( - llm_router.default_deployment is not None - or len(llm_router.provider_default_deployments) > 0 - ) - ): # model in router deployments, calling a specific deployment on the router - response = await llm_router.amoderation(**data) - elif user_model is not None: # `litellm --model ` - response = await litellm.amoderation(**data) - else: - # /moderations does not need a "model" passed - # see https://platform.openai.com/docs/api-reference/moderations - response = await litellm.amoderation(**data) + llm_call = await route_request( + data=data, + route_type="amoderation", + llm_router=llm_router, + user_model=user_model, + ) + response = await llm_call ### ALERTING ### asyncio.create_task( diff --git a/litellm/proxy/route_llm_request.py b/litellm/proxy/route_llm_request.py new file mode 100644 index 0000000000..7a7be55b22 --- /dev/null +++ b/litellm/proxy/route_llm_request.py @@ -0,0 +1,117 @@ +from typing import TYPE_CHECKING, Any, Literal, Optional, Union + +from fastapi import ( + Depends, + FastAPI, + File, + Form, + Header, + HTTPException, + Path, + Request, + Response, + UploadFile, + status, +) + +import litellm +from litellm._logging import verbose_logger + +if TYPE_CHECKING: + from litellm.router import Router as _Router + + LitellmRouter = _Router +else: + LitellmRouter = 
Any + + +ROUTE_ENDPOINT_MAPPING = { + "acompletion": "/chat/completions", + "atext_completion": "/completions", + "aembedding": "/embeddings", + "aimage_generation": "/image/generations", + "aspeech": "/audio/speech", + "atranscription": "/audio/transcriptions", + "amoderation": "/moderations", +} + + +async def route_request( + data: dict, + llm_router: Optional[LitellmRouter], + user_model: Optional[str], + route_type: Literal[ + "acompletion", + "atext_completion", + "aembedding", + "aimage_generation", + "aspeech", + "atranscription", + "amoderation", + ], +): + """ + Common helper to route the request + + """ + router_model_names = llm_router.model_names if llm_router is not None else [] + + if "api_key" in data: + return getattr(litellm, f"{route_type}")(**data) + + elif "user_config" in data: + router_config = data.pop("user_config") + user_router = litellm.Router(**router_config) + return getattr(user_router, f"{route_type}")(**data) + + elif ( + route_type == "acompletion" + and data.get("model", "") is not None + and "," in data.get("model", "") + and llm_router is not None + ): + if data.get("fastest_response", False): + return llm_router.abatch_completion_fastest_response(**data) + else: + models = [model.strip() for model in data.pop("model").split(",")] + return llm_router.abatch_completion(models=models, **data) + + elif llm_router is not None: + if ( + data["model"] in router_model_names + or data["model"] in llm_router.get_model_ids() + ): + return getattr(llm_router, f"{route_type}")(**data) + + elif ( + llm_router.model_group_alias is not None + and data["model"] in llm_router.model_group_alias + ): + return getattr(llm_router, f"{route_type}")(**data) + + elif data["model"] in llm_router.deployment_names: + return getattr(llm_router, f"{route_type}")( + **data, specific_deployment=True + ) + + elif data["model"] not in router_model_names: + if llm_router.router_general_settings.pass_through_all_models: + return getattr(litellm, f"{route_type}")(**data) + elif ( + llm_router.default_deployment is not None + or len(llm_router.provider_default_deployments) > 0 + ): + return getattr(llm_router, f"{route_type}")(**data) + + elif user_model is not None: + return getattr(litellm, f"{route_type}")(**data) + + # if no route found then it's a bad request + route_name = ROUTE_ENDPOINT_MAPPING.get(route_type, route_type) + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail={ + "error": f"{route_name}: Invalid model name passed in model=" + + data.get("model", "") + }, + ) diff --git a/litellm/proxy/spend_tracking/spend_tracking_utils.py b/litellm/proxy/spend_tracking/spend_tracking_utils.py index cd7004e41d..6a28d70b17 100644 --- a/litellm/proxy/spend_tracking/spend_tracking_utils.py +++ b/litellm/proxy/spend_tracking/spend_tracking_utils.py @@ -21,6 +21,8 @@ def get_logging_payload( if kwargs is None: kwargs = {} + if response_obj is None: + response_obj = {} # standardize this function to be used across, s3, dynamoDB, langfuse logging litellm_params = kwargs.get("litellm_params", {}) metadata = ( diff --git a/litellm/proxy/tests/test_anthropic_context_caching.py b/litellm/proxy/tests/test_anthropic_context_caching.py new file mode 100644 index 0000000000..6156e4a048 --- /dev/null +++ b/litellm/proxy/tests/test_anthropic_context_caching.py @@ -0,0 +1,37 @@ +import openai + +client = openai.OpenAI( + api_key="sk-1234", # litellm proxy api key + base_url="http://0.0.0.0:4000", # litellm proxy base url +) + + +response = client.chat.completions.create( + 
model="anthropic/claude-3-5-sonnet-20240620", + messages=[ + { # type: ignore + "role": "system", + "content": [ + { + "type": "text", + "text": "You are an AI assistant tasked with analyzing legal documents.", + }, + { + "type": "text", + "text": "Here is the full text of a complex legal agreement" * 100, + "cache_control": {"type": "ephemeral"}, + }, + ], + }, + { + "role": "user", + "content": "what are the key terms and conditions in this agreement?", + }, + ], + extra_headers={ + "anthropic-version": "2023-06-01", + "anthropic-beta": "prompt-caching-2024-07-31", + }, +) + +print(response) diff --git a/litellm/router_utils/client_initalization_utils.py b/litellm/router_utils/client_initalization_utils.py index 073a87901a..f396defb51 100644 --- a/litellm/router_utils/client_initalization_utils.py +++ b/litellm/router_utils/client_initalization_utils.py @@ -190,7 +190,7 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict): if azure_ad_token.startswith("oidc/"): azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token) if api_version is None: - api_version = litellm.AZURE_DEFAULT_API_VERSION + api_version = os.getenv("AZURE_API_VERSION", litellm.AZURE_DEFAULT_API_VERSION) if "gateway.ai.cloudflare.com" in api_base: if not api_base.endswith("/"): diff --git a/litellm/tests/test_anthropic_prompt_caching.py b/litellm/tests/test_anthropic_prompt_caching.py new file mode 100644 index 0000000000..87bfc23f84 --- /dev/null +++ b/litellm/tests/test_anthropic_prompt_caching.py @@ -0,0 +1,321 @@ +import json +import os +import sys +import traceback + +from dotenv import load_dotenv + +load_dotenv() +import io +import os + +sys.path.insert( + 0, os.path.abspath("../..") +) # Adds the parent directory to the system path + +import os +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +import litellm +from litellm import RateLimitError, Timeout, completion, completion_cost, embedding +from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler +from litellm.llms.prompt_templates.factory import anthropic_messages_pt + +# litellm.num_retries =3 +litellm.cache = None +litellm.success_callback = [] +user_message = "Write a short poem about the sky" +messages = [{"content": user_message, "role": "user"}] + + +def logger_fn(user_model_dict): + print(f"user_model_dict: {user_model_dict}") + + +@pytest.fixture(autouse=True) +def reset_callbacks(): + print("\npytest fixture - resetting callbacks") + litellm.success_callback = [] + litellm._async_success_callback = [] + litellm.failure_callback = [] + litellm.callbacks = [] + + +@pytest.mark.asyncio +async def test_litellm_anthropic_prompt_caching_tools(): + # Arrange: Set up the MagicMock for the httpx.AsyncClient + mock_response = AsyncMock() + + def return_val(): + return { + "id": "msg_01XFDUDYJgAACzvnptvVoYEL", + "type": "message", + "role": "assistant", + "content": [{"type": "text", "text": "Hello!"}], + "model": "claude-3-5-sonnet-20240620", + "stop_reason": "end_turn", + "stop_sequence": None, + "usage": {"input_tokens": 12, "output_tokens": 6}, + } + + mock_response.json = return_val + + litellm.set_verbose = True + with patch( + "litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post", + return_value=mock_response, + ) as mock_post: + # Act: Call the litellm.acompletion function + response = await litellm.acompletion( + api_key="mock_api_key", + model="anthropic/claude-3-5-sonnet-20240620", + messages=[ + {"role": "user", "content": "What's the weather like in Boston today?"} + ], 
+ tools=[ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + }, + "unit": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + }, + }, + "required": ["location"], + }, + "cache_control": {"type": "ephemeral"}, + }, + } + ], + extra_headers={ + "anthropic-version": "2023-06-01", + "anthropic-beta": "prompt-caching-2024-07-31", + }, + ) + + # Print what was called on the mock + print("call args=", mock_post.call_args) + + expected_url = "https://api.anthropic.com/v1/messages" + expected_headers = { + "accept": "application/json", + "content-type": "application/json", + "anthropic-version": "2023-06-01", + "anthropic-beta": "prompt-caching-2024-07-31", + "x-api-key": "mock_api_key", + } + + expected_json = { + "messages": [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": "What's the weather like in Boston today?", + } + ], + } + ], + "tools": [ + { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "cache_control": {"type": "ephemeral"}, + "input_schema": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + }, + "unit": { + "type": "string", + "enum": ["celsius", "fahrenheit"], + }, + }, + "required": ["location"], + }, + } + ], + "max_tokens": 4096, + "model": "claude-3-5-sonnet-20240620", + } + + mock_post.assert_called_once_with( + expected_url, json=expected_json, headers=expected_headers, timeout=600.0 + ) + + +@pytest.mark.asyncio() +async def test_anthropic_api_prompt_caching_basic(): + litellm.set_verbose = True + response = await litellm.acompletion( + model="anthropic/claude-3-5-sonnet-20240620", + messages=[ + # System Message + { + "role": "system", + "content": [ + { + "type": "text", + "text": "Here is the full text of a complex legal agreement" + * 400, + "cache_control": {"type": "ephemeral"}, + } + ], + }, + # marked for caching with the cache_control parameter, so that this checkpoint can read from the previous cache. + { + "role": "user", + "content": [ + { + "type": "text", + "text": "What are the key terms and conditions in this agreement?", + "cache_control": {"type": "ephemeral"}, + } + ], + }, + { + "role": "assistant", + "content": "Certainly! the key terms and conditions are the following: the contract is 1 year long for $10/mo", + }, + # The final turn is marked with cache-control, for continuing in followups. 
+ { + "role": "user", + "content": [ + { + "type": "text", + "text": "What are the key terms and conditions in this agreement?", + "cache_control": {"type": "ephemeral"}, + } + ], + }, + ], + temperature=0.2, + max_tokens=10, + extra_headers={ + "anthropic-version": "2023-06-01", + "anthropic-beta": "prompt-caching-2024-07-31", + }, + ) + + print("response=", response) + + assert "cache_read_input_tokens" in response.usage + assert "cache_creation_input_tokens" in response.usage + + # Assert either a cache entry was created or cache was read - changes depending on the anthropic api ttl + assert (response.usage.cache_read_input_tokens > 0) or ( + response.usage.cache_creation_input_tokens > 0 + ) + + +@pytest.mark.asyncio +async def test_litellm_anthropic_prompt_caching_system(): + # https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching#prompt-caching-examples + # LArge Context Caching Example + mock_response = AsyncMock() + + def return_val(): + return { + "id": "msg_01XFDUDYJgAACzvnptvVoYEL", + "type": "message", + "role": "assistant", + "content": [{"type": "text", "text": "Hello!"}], + "model": "claude-3-5-sonnet-20240620", + "stop_reason": "end_turn", + "stop_sequence": None, + "usage": {"input_tokens": 12, "output_tokens": 6}, + } + + mock_response.json = return_val + + litellm.set_verbose = True + with patch( + "litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post", + return_value=mock_response, + ) as mock_post: + # Act: Call the litellm.acompletion function + response = await litellm.acompletion( + api_key="mock_api_key", + model="anthropic/claude-3-5-sonnet-20240620", + messages=[ + { + "role": "system", + "content": [ + { + "type": "text", + "text": "You are an AI assistant tasked with analyzing legal documents.", + }, + { + "type": "text", + "text": "Here is the full text of a complex legal agreement", + "cache_control": {"type": "ephemeral"}, + }, + ], + }, + { + "role": "user", + "content": "what are the key terms and conditions in this agreement?", + }, + ], + extra_headers={ + "anthropic-version": "2023-06-01", + "anthropic-beta": "prompt-caching-2024-07-31", + }, + ) + + # Print what was called on the mock + print("call args=", mock_post.call_args) + + expected_url = "https://api.anthropic.com/v1/messages" + expected_headers = { + "accept": "application/json", + "content-type": "application/json", + "anthropic-version": "2023-06-01", + "anthropic-beta": "prompt-caching-2024-07-31", + "x-api-key": "mock_api_key", + } + + expected_json = { + "system": [ + { + "type": "text", + "text": "You are an AI assistant tasked with analyzing legal documents.", + }, + { + "type": "text", + "text": "Here is the full text of a complex legal agreement", + "cache_control": {"type": "ephemeral"}, + }, + ], + "messages": [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": "what are the key terms and conditions in this agreement?", + } + ], + } + ], + "max_tokens": 4096, + "model": "claude-3-5-sonnet-20240620", + } + + mock_post.assert_called_once_with( + expected_url, json=expected_json, headers=expected_headers, timeout=600.0 + ) diff --git a/litellm/tests/test_bedrock_completion.py b/litellm/tests/test_bedrock_completion.py index 4da18144d0..c331021213 100644 --- a/litellm/tests/test_bedrock_completion.py +++ b/litellm/tests/test_bedrock_completion.py @@ -1159,8 +1159,8 @@ def test_bedrock_tools_pt_invalid_names(): assert result[1]["toolSpec"]["name"] == "another_invalid_name" -def test_bad_request_error(): - with 
pytest.raises(litellm.BadRequestError): +def test_not_found_error(): + with pytest.raises(litellm.NotFoundError): completion( model="bedrock/bad_model", messages=[ diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py index db0239ca33..654b210ff7 100644 --- a/litellm/tests/test_completion.py +++ b/litellm/tests/test_completion.py @@ -14,7 +14,7 @@ sys.path.insert( ) # Adds the parent directory to the system path import os -from unittest.mock import MagicMock, patch +from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -23,7 +23,7 @@ from litellm import RateLimitError, Timeout, completion, completion_cost, embedd from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler from litellm.llms.prompt_templates.factory import anthropic_messages_pt -# litellm.num_retries = 3 +# litellm.num_retries =3 litellm.cache = None litellm.success_callback = [] user_message = "Write a short poem about the sky" @@ -190,6 +190,31 @@ def test_completion_azure_command_r(): pytest.fail(f"Error occurred: {e}") +@pytest.mark.parametrize( + "api_base", + [ + "https://litellm8397336933.openai.azure.com", + "https://litellm8397336933.openai.azure.com/openai/deployments/gpt-4o/chat/completions?api-version=2023-03-15-preview", + ], +) +def test_completion_azure_ai_gpt_4o(api_base): + try: + litellm.set_verbose = True + + response = completion( + model="azure_ai/gpt-4o", + api_base=api_base, + api_key=os.getenv("AZURE_AI_OPENAI_KEY"), + messages=[{"role": "user", "content": "What is the meaning of life?"}], + ) + + print(response) + except litellm.Timeout as e: + pass + except Exception as e: + pytest.fail(f"Error occurred: {e}") + + @pytest.mark.parametrize("sync_mode", [True, False]) @pytest.mark.asyncio async def test_completion_databricks(sync_mode): @@ -3312,108 +3337,6 @@ def test_customprompt_together_ai(): # test_customprompt_together_ai() -@pytest.mark.skip(reason="AWS Suspended Account") -def test_completion_sagemaker(): - try: - litellm.set_verbose = True - print("testing sagemaker") - response = completion( - model="sagemaker/jumpstart-dft-hf-llm-mistral-7b-ins-20240329-150233", - model_id="huggingface-llm-mistral-7b-instruct-20240329-150233", - messages=messages, - temperature=0.2, - max_tokens=80, - aws_region_name=os.getenv("AWS_REGION_NAME_2"), - aws_access_key_id=os.getenv("AWS_ACCESS_KEY_ID_2"), - aws_secret_access_key=os.getenv("AWS_SECRET_ACCESS_KEY_2"), - input_cost_per_second=0.000420, - ) - # Add any assertions here to check the response - print(response) - cost = completion_cost(completion_response=response) - print("calculated cost", cost) - assert ( - cost > 0.0 and cost < 1.0 - ) # should never be > $1 for a single completion call - except Exception as e: - pytest.fail(f"Error occurred: {e}") - - -# test_completion_sagemaker() - - -@pytest.mark.skip(reason="AWS Suspended Account") -@pytest.mark.asyncio -async def test_acompletion_sagemaker(): - try: - litellm.set_verbose = True - print("testing sagemaker") - response = await litellm.acompletion( - model="sagemaker/jumpstart-dft-hf-llm-mistral-7b-ins-20240329-150233", - model_id="huggingface-llm-mistral-7b-instruct-20240329-150233", - messages=messages, - temperature=0.2, - max_tokens=80, - aws_region_name=os.getenv("AWS_REGION_NAME_2"), - aws_access_key_id=os.getenv("AWS_ACCESS_KEY_ID_2"), - aws_secret_access_key=os.getenv("AWS_SECRET_ACCESS_KEY_2"), - input_cost_per_second=0.000420, - ) - # Add any assertions here to check the response - print(response) - cost = 
completion_cost(completion_response=response) - print("calculated cost", cost) - assert ( - cost > 0.0 and cost < 1.0 - ) # should never be > $1 for a single completion call - except Exception as e: - pytest.fail(f"Error occurred: {e}") - - -@pytest.mark.skip(reason="AWS Suspended Account") -def test_completion_chat_sagemaker(): - try: - messages = [{"role": "user", "content": "Hey, how's it going?"}] - litellm.set_verbose = True - response = completion( - model="sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4", - messages=messages, - max_tokens=100, - temperature=0.7, - stream=True, - ) - # Add any assertions here to check the response - complete_response = "" - for chunk in response: - complete_response += chunk.choices[0].delta.content or "" - print(f"complete_response: {complete_response}") - assert len(complete_response) > 0 - except Exception as e: - pytest.fail(f"Error occurred: {e}") - - -# test_completion_chat_sagemaker() - - -@pytest.mark.skip(reason="AWS Suspended Account") -def test_completion_chat_sagemaker_mistral(): - try: - messages = [{"role": "user", "content": "Hey, how's it going?"}] - - response = completion( - model="sagemaker/jumpstart-dft-hf-llm-mistral-7b-instruct", - messages=messages, - max_tokens=100, - ) - # Add any assertions here to check the response - print(response) - except Exception as e: - pytest.fail(f"An error occurred: {str(e)}") - - -# test_completion_chat_sagemaker_mistral() - - def response_format_tests(response: litellm.ModelResponse): assert isinstance(response.id, str) assert response.id != "" @@ -3449,7 +3372,6 @@ def response_format_tests(response: litellm.ModelResponse): assert isinstance(response.usage.total_tokens, int) # type: ignore -@pytest.mark.parametrize("sync_mode", [True, False]) @pytest.mark.parametrize( "model", [ @@ -3463,6 +3385,7 @@ def response_format_tests(response: litellm.ModelResponse): "cohere.command-text-v14", ], ) +@pytest.mark.parametrize("sync_mode", [True, False]) @pytest.mark.asyncio async def test_completion_bedrock_httpx_models(sync_mode, model): litellm.set_verbose = True @@ -3705,19 +3628,21 @@ def test_completion_anyscale_api(): # test_completion_anyscale_api() -@pytest.mark.skip(reason="flaky test, times out frequently") +# @pytest.mark.skip(reason="flaky test, times out frequently") def test_completion_cohere(): try: # litellm.set_verbose=True messages = [ {"role": "system", "content": "You're a good bot"}, + {"role": "assistant", "content": [{"text": "2", "type": "text"}]}, + {"role": "assistant", "content": [{"text": "3", "type": "text"}]}, { "role": "user", "content": "Hey", }, ] response = completion( - model="command-nightly", + model="command-r", messages=messages, ) print(response) diff --git a/litellm/tests/test_function_call_parsing.py b/litellm/tests/test_function_call_parsing.py index d223a7c8f6..fab9cf110c 100644 --- a/litellm/tests/test_function_call_parsing.py +++ b/litellm/tests/test_function_call_parsing.py @@ -1,23 +1,27 @@ # What is this? ## Test to make sure function call response always works with json.loads() -> no extra parsing required. 
Relevant issue - https://github.com/BerriAI/litellm/issues/2654 -import sys, os +import os +import sys import traceback + from dotenv import load_dotenv load_dotenv() -import os, io +import io +import os sys.path.insert( 0, os.path.abspath("../..") ) # Adds the parent directory to the system path -import pytest -import litellm import json import warnings - -from litellm import completion from typing import List +import pytest + +import litellm +from litellm import completion + # Just a stub to keep the sample code simple class Trade: @@ -78,58 +82,60 @@ def trade(model_name: str) -> List[Trade]: }, } - response = completion( - model_name, - [ - { - "role": "system", - "content": """You are an expert asset manager, managing a portfolio. + try: + response = completion( + model_name, + [ + { + "role": "system", + "content": """You are an expert asset manager, managing a portfolio. - Always use the `trade` function. Make sure that you call it correctly. For example, the following is a valid call: + Always use the `trade` function. Make sure that you call it correctly. For example, the following is a valid call: + ``` + trade({ + "orders": [ + {"action": "buy", "asset": "BTC", "amount": 0.1}, + {"action": "sell", "asset": "ETH", "amount": 0.2} + ] + }) + ``` + + If there are no trades to make, call `trade` with an empty array: + ``` + trade({ "orders": [] }) + ``` + """, + }, + { + "role": "user", + "content": """Manage the portfolio. + + Don't jabber. + + This is the current market data: ``` - trade({ - "orders": [ - {"action": "buy", "asset": "BTC", "amount": 0.1}, - {"action": "sell", "asset": "ETH", "amount": 0.2} - ] - }) + {market_data} ``` - If there are no trades to make, call `trade` with an empty array: + Your portfolio is as follows: ``` - trade({ "orders": [] }) + {portfolio} ``` - """, + """.replace( + "{market_data}", "BTC: 64,000 USD\nETH: 3,500 USD" + ).replace( + "{portfolio}", "USD: 1000, BTC: 0.1, ETH: 0.2" + ), + }, + ], + tools=[tool_spec], + tool_choice={ + "type": "function", + "function": {"name": tool_spec["function"]["name"]}, # type: ignore }, - { - "role": "user", - "content": """Manage the portfolio. - - Don't jabber. 
- - This is the current market data: - ``` - {market_data} - ``` - - Your portfolio is as follows: - ``` - {portfolio} - ``` - """.replace( - "{market_data}", "BTC: 64,000 USD\nETH: 3,500 USD" - ).replace( - "{portfolio}", "USD: 1000, BTC: 0.1, ETH: 0.2" - ), - }, - ], - tools=[tool_spec], - tool_choice={ - "type": "function", - "function": {"name": tool_spec["function"]["name"]}, # type: ignore - }, - ) - + ) + except litellm.InternalServerError: + pass calls = response.choices[0].message.tool_calls trades = [trade for call in calls for trade in parse_call(call)] return trades diff --git a/litellm/tests/test_gcs_bucket.py b/litellm/tests/test_gcs_bucket.py index c21988c73d..f0aaf8d8dd 100644 --- a/litellm/tests/test_gcs_bucket.py +++ b/litellm/tests/test_gcs_bucket.py @@ -147,6 +147,117 @@ async def test_basic_gcs_logger(): assert gcs_payload["response_cost"] > 0.0 + assert gcs_payload["log_event_type"] == "successful_api_call" + gcs_payload["spend_log_metadata"] = json.loads(gcs_payload["spend_log_metadata"]) + + assert ( + gcs_payload["spend_log_metadata"]["user_api_key"] + == "88dc28d0f030c55ed4ab77ed8faf098196cb1c05df778539800c9f1243fe6b4b" + ) + assert ( + gcs_payload["spend_log_metadata"]["user_api_key_user_id"] + == "116544810872468347480" + ) + + # Delete Object from GCS + print("deleting object from GCS") + await gcs_logger.delete_gcs_object(object_name=object_name) + + +@pytest.mark.asyncio +async def test_basic_gcs_logger_failure(): + load_vertex_ai_credentials() + gcs_logger = GCSBucketLogger() + print("GCSBucketLogger", gcs_logger) + + gcs_log_id = f"failure-test-{uuid.uuid4().hex}" + + litellm.callbacks = [gcs_logger] + + try: + response = await litellm.acompletion( + model="gpt-3.5-turbo", + temperature=0.7, + messages=[{"role": "user", "content": "This is a test"}], + max_tokens=10, + user="ishaan-2", + mock_response=litellm.BadRequestError( + model="gpt-3.5-turbo", + message="Error: 400: Bad Request: Invalid API key, please check your API key and try again.", + llm_provider="openai", + ), + metadata={ + "gcs_log_id": gcs_log_id, + "tags": ["model-anthropic-claude-v2.1", "app-ishaan-prod"], + "user_api_key": "88dc28d0f030c55ed4ab77ed8faf098196cb1c05df778539800c9f1243fe6b4b", + "user_api_key_alias": None, + "user_api_end_user_max_budget": None, + "litellm_api_version": "0.0.0", + "global_max_parallel_requests": None, + "user_api_key_user_id": "116544810872468347480", + "user_api_key_org_id": None, + "user_api_key_team_id": None, + "user_api_key_team_alias": None, + "user_api_key_metadata": {}, + "requester_ip_address": "127.0.0.1", + "spend_logs_metadata": {"hello": "world"}, + "headers": { + "content-type": "application/json", + "user-agent": "PostmanRuntime/7.32.3", + "accept": "*/*", + "postman-token": "92300061-eeaa-423b-a420-0b44896ecdc4", + "host": "localhost:4000", + "accept-encoding": "gzip, deflate, br", + "connection": "keep-alive", + "content-length": "163", + }, + "endpoint": "http://localhost:4000/chat/completions", + "model_group": "gpt-3.5-turbo", + "deployment": "azure/chatgpt-v-2", + "model_info": { + "id": "4bad40a1eb6bebd1682800f16f44b9f06c52a6703444c99c7f9f32e9de3693b4", + "db_model": False, + }, + "api_base": "https://openai-gpt-4-test-v-1.openai.azure.com/", + "caching_groups": None, + "raw_request": "\n\nPOST Request Sent from LiteLLM:\ncurl -X POST \\\nhttps://openai-gpt-4-test-v-1.openai.azure.com//openai/ \\\n-H 'Authorization: *****' \\\n-d '{'model': 'chatgpt-v-2', 'messages': [{'role': 'system', 'content': 'you are a helpful assistant.\\n'}, 
{'role': 'user', 'content': 'bom dia'}], 'stream': False, 'max_tokens': 10, 'user': '116544810872468347480', 'extra_body': {}}'\n", + }, + ) + except: + pass + + await asyncio.sleep(5) + + # Get the current date + # Get the current date + current_date = datetime.now().strftime("%Y-%m-%d") + + # Modify the object_name to include the date-based folder + object_name = gcs_log_id + + print("object_name", object_name) + + # Check if object landed on GCS + object_from_gcs = await gcs_logger.download_gcs_object(object_name=object_name) + print("object from gcs=", object_from_gcs) + # convert object_from_gcs from bytes to DICT + parsed_data = json.loads(object_from_gcs) + print("object_from_gcs as dict", parsed_data) + + print("type of object_from_gcs", type(parsed_data)) + + gcs_payload = GCSBucketPayload(**parsed_data) + + print("gcs_payload", gcs_payload) + + assert gcs_payload["request_kwargs"]["model"] == "gpt-3.5-turbo" + assert gcs_payload["request_kwargs"]["messages"] == [ + {"role": "user", "content": "This is a test"} + ] + + assert gcs_payload["response_cost"] == 0 + assert gcs_payload["log_event_type"] == "failed_api_call" + gcs_payload["spend_log_metadata"] = json.loads(gcs_payload["spend_log_metadata"]) assert ( diff --git a/litellm/tests/test_prometheus.py b/litellm/tests/test_prometheus.py index 64e824e6db..7574beb9d9 100644 --- a/litellm/tests/test_prometheus.py +++ b/litellm/tests/test_prometheus.py @@ -76,6 +76,6 @@ async def test_async_prometheus_success_logging(): print("metrics from prometheus", metrics) assert metrics["litellm_requests_metric_total"] == 1.0 assert metrics["litellm_total_tokens_total"] == 30.0 - assert metrics["llm_deployment_success_responses_total"] == 1.0 - assert metrics["llm_deployment_total_requests_total"] == 1.0 - assert metrics["llm_deployment_latency_per_output_token_bucket"] == 1.0 + assert metrics["litellm_deployment_success_responses_total"] == 1.0 + assert metrics["litellm_deployment_total_requests_total"] == 1.0 + assert metrics["litellm_deployment_latency_per_output_token_bucket"] == 1.0 diff --git a/litellm/tests/test_prompt_factory.py b/litellm/tests/test_prompt_factory.py index f7a715a220..93e92a7926 100644 --- a/litellm/tests/test_prompt_factory.py +++ b/litellm/tests/test_prompt_factory.py @@ -260,3 +260,56 @@ def test_anthropic_messages_tool_call(): translated_messages[-1]["content"][0]["tool_use_id"] == "bc8cb4b6-88c4-4138-8993-3a9d9cd51656" ) + + +def test_anthropic_cache_controls_pt(): + "see anthropic docs for this: https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching#continuing-a-multi-turn-conversation" + messages = [ + # marked for caching with the cache_control parameter, so that this checkpoint can read from the previous cache. + { + "role": "user", + "content": [ + { + "type": "text", + "text": "What are the key terms and conditions in this agreement?", + "cache_control": {"type": "ephemeral"}, + } + ], + }, + { + "role": "assistant", + "content": "Certainly! the key terms and conditions are the following: the contract is 1 year long for $10/mo", + }, + # The final turn is marked with cache-control, for continuing in followups. + { + "role": "user", + "content": [ + { + "type": "text", + "text": "What are the key terms and conditions in this agreement?", + "cache_control": {"type": "ephemeral"}, + } + ], + }, + { + "role": "assistant", + "content": "Certainly! 
the key terms and conditions are the following: the contract is 1 year long for $10/mo", + "cache_control": {"type": "ephemeral"}, + }, + ] + + translated_messages = anthropic_messages_pt( + messages, model="claude-3-5-sonnet-20240620", llm_provider="anthropic" + ) + + for i, msg in enumerate(translated_messages): + if i == 0: + assert msg["content"][0]["cache_control"] == {"type": "ephemeral"} + elif i == 1: + assert "cache_controls" not in msg["content"][0] + elif i == 2: + assert msg["content"][0]["cache_control"] == {"type": "ephemeral"} + elif i == 3: + assert msg["content"][0]["cache_control"] == {"type": "ephemeral"} + + print("translated_messages: ", translated_messages) diff --git a/litellm/tests/test_provider_specific_config.py b/litellm/tests/test_provider_specific_config.py index c20c44fb13..a7765f658c 100644 --- a/litellm/tests/test_provider_specific_config.py +++ b/litellm/tests/test_provider_specific_config.py @@ -2,16 +2,19 @@ # This tests setting provider specific configs across providers # There are 2 types of tests - changing config dynamically or by setting class variables -import sys, os +import os +import sys import traceback + import pytest sys.path.insert( 0, os.path.abspath("../..") ) # Adds the parent directory to the system path +from unittest.mock import AsyncMock, MagicMock, patch + import litellm -from litellm import completion -from litellm import RateLimitError +from litellm import RateLimitError, completion # Huggingface - Expensive to deploy models and keep them running. Maybe we can try doing this via baseten?? # def hf_test_completion_tgi(): @@ -513,102 +516,165 @@ def sagemaker_test_completion(): # sagemaker_test_completion() -def test_sagemaker_default_region(mocker): +def test_sagemaker_default_region(): """ If no regions are specified in config or in environment, the default region is us-west-2 """ - mock_client = mocker.patch("boto3.client") - try: + mock_response = MagicMock() + + def return_val(): + return { + "generated_text": "This is a mock response from SageMaker.", + "id": "cmpl-mockid", + "object": "text_completion", + "created": 1629800000, + "model": "sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614", + "choices": [ + { + "text": "This is a mock response from SageMaker.", + "index": 0, + "logprobs": None, + "finish_reason": "length", + } + ], + "usage": {"prompt_tokens": 1, "completion_tokens": 8, "total_tokens": 9}, + } + + mock_response.json = return_val + mock_response.status_code = 200 + + with patch( + "litellm.llms.custom_httpx.http_handler.HTTPHandler.post", + return_value=mock_response, + ) as mock_post: response = litellm.completion( model="sagemaker/mock-endpoint", - messages=[ - { - "content": "Hello, world!", - "role": "user" - } - ] + messages=[{"content": "Hello, world!", "role": "user"}], ) - except Exception: - pass # expected serialization exception because AWS client was replaced with a Mock - assert mock_client.call_args.kwargs["region_name"] == "us-west-2" + mock_post.assert_called_once() + _, kwargs = mock_post.call_args + args_to_sagemaker = kwargs["json"] + print("Arguments passed to sagemaker=", args_to_sagemaker) + print("url=", kwargs["url"]) + + assert ( + kwargs["url"] + == "https://runtime.sagemaker.us-west-2.amazonaws.com/endpoints/mock-endpoint/invocations" + ) + # test_sagemaker_default_region() -def test_sagemaker_environment_region(mocker): +def test_sagemaker_environment_region(): """ If a region is specified in the environment, use that region instead of us-west-2 """ expected_region = 
"us-east-1" os.environ["AWS_REGION_NAME"] = expected_region - mock_client = mocker.patch("boto3.client") - try: + mock_response = MagicMock() + + def return_val(): + return { + "generated_text": "This is a mock response from SageMaker.", + "id": "cmpl-mockid", + "object": "text_completion", + "created": 1629800000, + "model": "sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614", + "choices": [ + { + "text": "This is a mock response from SageMaker.", + "index": 0, + "logprobs": None, + "finish_reason": "length", + } + ], + "usage": {"prompt_tokens": 1, "completion_tokens": 8, "total_tokens": 9}, + } + + mock_response.json = return_val + mock_response.status_code = 200 + + with patch( + "litellm.llms.custom_httpx.http_handler.HTTPHandler.post", + return_value=mock_response, + ) as mock_post: response = litellm.completion( model="sagemaker/mock-endpoint", - messages=[ - { - "content": "Hello, world!", - "role": "user" - } - ] + messages=[{"content": "Hello, world!", "role": "user"}], ) - except Exception: - pass # expected serialization exception because AWS client was replaced with a Mock - del os.environ["AWS_REGION_NAME"] # cleanup - assert mock_client.call_args.kwargs["region_name"] == expected_region + mock_post.assert_called_once() + _, kwargs = mock_post.call_args + args_to_sagemaker = kwargs["json"] + print("Arguments passed to sagemaker=", args_to_sagemaker) + print("url=", kwargs["url"]) + + assert ( + kwargs["url"] + == f"https://runtime.sagemaker.{expected_region}.amazonaws.com/endpoints/mock-endpoint/invocations" + ) + + del os.environ["AWS_REGION_NAME"] # cleanup + # test_sagemaker_environment_region() -def test_sagemaker_config_region(mocker): +def test_sagemaker_config_region(): """ If a region is specified as part of the optional parameters of the completion, including as part of the config file, then use that region instead of us-west-2 """ expected_region = "us-east-1" - mock_client = mocker.patch("boto3.client") - try: - response = litellm.completion( - model="sagemaker/mock-endpoint", - messages=[ + mock_response = MagicMock() + + def return_val(): + return { + "generated_text": "This is a mock response from SageMaker.", + "id": "cmpl-mockid", + "object": "text_completion", + "created": 1629800000, + "model": "sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614", + "choices": [ { - "content": "Hello, world!", - "role": "user" + "text": "This is a mock response from SageMaker.", + "index": 0, + "logprobs": None, + "finish_reason": "length", } ], + "usage": {"prompt_tokens": 1, "completion_tokens": 8, "total_tokens": 9}, + } + + mock_response.json = return_val + mock_response.status_code = 200 + + with patch( + "litellm.llms.custom_httpx.http_handler.HTTPHandler.post", + return_value=mock_response, + ) as mock_post: + + response = litellm.completion( + model="sagemaker/mock-endpoint", + messages=[{"content": "Hello, world!", "role": "user"}], aws_region_name=expected_region, ) - except Exception: - pass # expected serialization exception because AWS client was replaced with a Mock - assert mock_client.call_args.kwargs["region_name"] == expected_region + + mock_post.assert_called_once() + _, kwargs = mock_post.call_args + args_to_sagemaker = kwargs["json"] + print("Arguments passed to sagemaker=", args_to_sagemaker) + print("url=", kwargs["url"]) + + assert ( + kwargs["url"] + == f"https://runtime.sagemaker.{expected_region}.amazonaws.com/endpoints/mock-endpoint/invocations" + ) + # test_sagemaker_config_region() -def 
-    """
-    If both the environment and config file specify a region, the environment region is expected
-    """
-    expected_region = "us-east-1"
-    unexpected_region = "us-east-2"
-    os.environ["AWS_REGION_NAME"] = expected_region
-    mock_client = mocker.patch("boto3.client")
-    try:
-        response = litellm.completion(
-            model="sagemaker/mock-endpoint",
-            messages=[
-                {
-                    "content": "Hello, world!",
-                    "role": "user"
-                }
-            ],
-            aws_region_name=unexpected_region,
-        )
-    except Exception:
-        pass  # expected serialization exception because AWS client was replaced with a Mock
-    del os.environ["AWS_REGION_NAME"]  # cleanup
-    assert mock_client.call_args.kwargs["region_name"] == expected_region
-# test_sagemaker_config_and_environment_region()
diff --git a/litellm/tests/test_proxy_exception_mapping.py b/litellm/tests/test_proxy_exception_mapping.py
index a774d1b0ef..89b4cd926e 100644
--- a/litellm/tests/test_proxy_exception_mapping.py
+++ b/litellm/tests/test_proxy_exception_mapping.py
@@ -229,8 +229,9 @@ def test_chat_completion_exception_any_model(client):
         )
         assert isinstance(openai_exception, openai.BadRequestError)
         _error_message = openai_exception.message
-        assert "chat_completion: Invalid model name passed in model=Lite-GPT-12" in str(
-            _error_message
+        assert (
+            "/chat/completions: Invalid model name passed in model=Lite-GPT-12"
+            in str(_error_message)
         )
     except Exception as e:
@@ -259,7 +260,7 @@ def test_embedding_exception_any_model(client):
         print("Exception raised=", openai_exception)
         assert isinstance(openai_exception, openai.BadRequestError)
         _error_message = openai_exception.message
-        assert "embeddings: Invalid model name passed in model=Lite-GPT-12" in str(
+        assert "/embeddings: Invalid model name passed in model=Lite-GPT-12" in str(
             _error_message
         )
diff --git a/litellm/tests/test_proxy_server.py b/litellm/tests/test_proxy_server.py
index 757eef6d62..9a1c091267 100644
--- a/litellm/tests/test_proxy_server.py
+++ b/litellm/tests/test_proxy_server.py
@@ -966,3 +966,203 @@ async def test_user_info_team_list(prisma_client):
         pass
     mock_client.assert_called()
+
+
+@pytest.mark.skip(reason="Local test")
+@pytest.mark.asyncio
+async def test_add_callback_via_key(prisma_client):
+    """
+    Test if callback specified in key, is used.
+    """
+    global headers
+    import json
+
+    from fastapi import HTTPException, Request, Response
+    from starlette.datastructures import URL
+
+    from litellm.proxy.proxy_server import chat_completion
+
+    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
+    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
+    await litellm.proxy.proxy_server.prisma_client.connect()
+
+    litellm.set_verbose = True
+
+    try:
+        # Your test data
+        test_data = {
+            "model": "azure/chatgpt-v-2",
+            "messages": [
+                {"role": "user", "content": "write 1 sentence poem"},
+            ],
+            "max_tokens": 10,
+            "mock_response": "Hello world",
+            "api_key": "my-fake-key",
+        }
+
+        request = Request(scope={"type": "http", "method": "POST", "headers": {}})
+        request._url = URL(url="/chat/completions")
+
+        json_bytes = json.dumps(test_data).encode("utf-8")
+
+        request._body = json_bytes
+
+        with patch.object(
+            litellm.litellm_core_utils.litellm_logging,
+            "LangFuseLogger",
+            new=MagicMock(),
+        ) as mock_client:
+            resp = await chat_completion(
+                request=request,
+                fastapi_response=Response(),
+                user_api_key_dict=UserAPIKeyAuth(
+                    metadata={
+                        "logging": [
+                            {
+                                "callback_name": "langfuse",  # 'otel', 'langfuse', 'lunary'
+                                "callback_type": "success",  # set, if required by integration - future improvement, have logging tools work for success + failure by default
+                                "callback_vars": {
+                                    "langfuse_public_key": "os.environ/LANGFUSE_PUBLIC_KEY",
+                                    "langfuse_secret_key": "os.environ/LANGFUSE_SECRET_KEY",
+                                    "langfuse_host": "https://us.cloud.langfuse.com",
+                                },
+                            }
+                        ]
+                    }
+                ),
+            )
+            print(resp)
+            mock_client.assert_called()
+            mock_client.return_value.log_event.assert_called()
+            args, kwargs = mock_client.return_value.log_event.call_args
+            kwargs = kwargs["kwargs"]
+            assert "user_api_key_metadata" in kwargs["litellm_params"]["metadata"]
+            assert (
+                "logging"
+                in kwargs["litellm_params"]["metadata"]["user_api_key_metadata"]
+            )
+            checked_keys = False
+            for item in kwargs["litellm_params"]["metadata"]["user_api_key_metadata"][
+                "logging"
+            ]:
+                for k, v in item["callback_vars"].items():
+                    print("k={}, v={}".format(k, v))
+                    if "key" in k:
+                        assert "os.environ" in v
+                        checked_keys = True
+
+            assert checked_keys
+    except Exception as e:
+        pytest.fail(f"LiteLLM Proxy test failed. Exception - {str(e)}")
+
+
+@pytest.mark.asyncio
+async def test_add_callback_via_key_litellm_pre_call_utils(prisma_client):
+    import json
+
+    from fastapi import HTTPException, Request, Response
+    from starlette.datastructures import URL
+
+    from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request
+
+    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
+    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
+    await litellm.proxy.proxy_server.prisma_client.connect()
+
+    proxy_config = getattr(litellm.proxy.proxy_server, "proxy_config")
+
+    request = Request(scope={"type": "http", "method": "POST", "headers": {}})
+    request._url = URL(url="/chat/completions")
+
+    test_data = {
+        "model": "azure/chatgpt-v-2",
+        "messages": [
+            {"role": "user", "content": "write 1 sentence poem"},
+        ],
+        "max_tokens": 10,
+        "mock_response": "Hello world",
+        "api_key": "my-fake-key",
+    }
+
+    json_bytes = json.dumps(test_data).encode("utf-8")
+
+    request._body = json_bytes
+
+    data = {
+        "data": {
+            "model": "azure/chatgpt-v-2",
+            "messages": [{"role": "user", "content": "write 1 sentence poem"}],
+            "max_tokens": 10,
+            "mock_response": "Hello world",
+            "api_key": "my-fake-key",
+        },
+        "request": request,
+        "user_api_key_dict": UserAPIKeyAuth(
+            token=None,
+            key_name=None,
+            key_alias=None,
+            spend=0.0,
+            max_budget=None,
+            expires=None,
+            models=[],
+            aliases={},
+            config={},
+            user_id=None,
+            team_id=None,
+            max_parallel_requests=None,
+            metadata={
+                "logging": [
+                    {
+                        "callback_name": "langfuse",
+                        "callback_type": "success",
+                        "callback_vars": {
+                            "langfuse_public_key": "os.environ/LANGFUSE_PUBLIC_KEY",
+                            "langfuse_secret_key": "os.environ/LANGFUSE_SECRET_KEY",
+                            "langfuse_host": "https://us.cloud.langfuse.com",
+                        },
+                    }
+                ]
+            },
+            tpm_limit=None,
+            rpm_limit=None,
+            budget_duration=None,
+            budget_reset_at=None,
+            allowed_cache_controls=[],
+            permissions={},
+            model_spend={},
+            model_max_budget={},
+            soft_budget_cooldown=False,
+            litellm_budget_table=None,
+            org_id=None,
+            team_spend=None,
+            team_alias=None,
+            team_tpm_limit=None,
+            team_rpm_limit=None,
+            team_max_budget=None,
+            team_models=[],
+            team_blocked=False,
+            soft_budget=None,
+            team_model_aliases=None,
+            team_member_spend=None,
+            team_metadata=None,
+            end_user_id=None,
+            end_user_tpm_limit=None,
+            end_user_rpm_limit=None,
+            end_user_max_budget=None,
+            last_refreshed_at=None,
+            api_key=None,
+            user_role=None,
+            allowed_model_region=None,
+            parent_otel_span=None,
+        ),
+        "proxy_config": proxy_config,
+        "general_settings": {},
+        "version": "0.0.0",
+    }
+
+    new_data = await add_litellm_data_to_request(**data)
+
+    assert "success_callback" in new_data
+    assert new_data["success_callback"] == ["langfuse"]
+    assert "langfuse_public_key" in new_data
+    assert "langfuse_secret_key" in new_data
diff --git a/litellm/tests/test_sagemaker.py b/litellm/tests/test_sagemaker.py
new file mode 100644
index 0000000000..b6b4251c6b
--- /dev/null
+++ b/litellm/tests/test_sagemaker.py
@@ -0,0 +1,316 @@
+import json
+import os
+import sys
+import traceback
+
+from dotenv import load_dotenv
+
+load_dotenv()
+import io
+import os
+
+sys.path.insert(
+    0, os.path.abspath("../..")
+)  # Adds the parent directory to the system path
+
+import os
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import pytest
+
+import litellm
+from litellm import RateLimitError, Timeout, completion, completion_cost, embedding
+from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
+from litellm.llms.prompt_templates.factory import anthropic_messages_pt
+
+# litellm.num_retries =3
+litellm.cache = None
+litellm.success_callback = []
+user_message = "Write a short poem about the sky"
+messages = [{"content": user_message, "role": "user"}]
+import logging
+
+from litellm._logging import verbose_logger
+
+
+def logger_fn(user_model_dict):
+    print(f"user_model_dict: {user_model_dict}")
+
+
+@pytest.fixture(autouse=True)
+def reset_callbacks():
+    print("\npytest fixture - resetting callbacks")
+    litellm.success_callback = []
+    litellm._async_success_callback = []
+    litellm.failure_callback = []
+    litellm.callbacks = []
+
+
+@pytest.mark.asyncio()
+@pytest.mark.parametrize("sync_mode", [True, False])
+async def test_completion_sagemaker(sync_mode):
+    try:
+        litellm.set_verbose = True
+        verbose_logger.setLevel(logging.DEBUG)
+        print("testing sagemaker")
+        if sync_mode is True:
+            response = litellm.completion(
+                model="sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
+                messages=[
+                    {"role": "user", "content": "hi"},
+                ],
+                temperature=0.2,
+                max_tokens=80,
+                input_cost_per_second=0.000420,
+            )
+        else:
+            response = await litellm.acompletion(
+                model="sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
+                messages=[
+                    {"role": "user", "content": "hi"},
+                ],
+                temperature=0.2,
+                max_tokens=80,
+                input_cost_per_second=0.000420,
+            )
+        # Add any assertions here to check the response
+        print(response)
+        cost = completion_cost(completion_response=response)
+        print("calculated cost", cost)
+        assert (
+            cost > 0.0 and cost < 1.0
+        )  # should never be > $1 for a single completion call
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+
+@pytest.mark.asyncio()
+@pytest.mark.parametrize("sync_mode", [False, True])
+async def test_completion_sagemaker_stream(sync_mode):
+    try:
+        litellm.set_verbose = False
+        print("testing sagemaker")
+        verbose_logger.setLevel(logging.DEBUG)
+        full_text = ""
+        if sync_mode is True:
+            response = litellm.completion(
+                model="sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
+                messages=[
+                    {"role": "user", "content": "hi - what is ur name"},
+                ],
+                temperature=0.2,
+                stream=True,
+                max_tokens=80,
+                input_cost_per_second=0.000420,
+            )
+
+            for chunk in response:
+                print(chunk)
+                full_text += chunk.choices[0].delta.content or ""
+
+            print("SYNC RESPONSE full text", full_text)
+        else:
+            response = await litellm.acompletion(
+                model="sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
+                messages=[
+                    {"role": "user", "content": "hi - what is ur name"},
+                ],
+                stream=True,
+                temperature=0.2,
+                max_tokens=80,
+                input_cost_per_second=0.000420,
+            )
+
+            print("streaming response")
+
+            async for chunk in response:
+                print(chunk)
+                full_text += chunk.choices[0].delta.content or ""
+
+            print("ASYNC RESPONSE full text", full_text)
+
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+
+@pytest.mark.asyncio
+async def test_acompletion_sagemaker_non_stream():
+    mock_response = AsyncMock()
+
+    def return_val():
+        return {
+            "generated_text": "This is a mock response from SageMaker.",
+            "id": "cmpl-mockid",
+            "object": "text_completion",
+            "created": 1629800000,
+            "model": "sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
+            "choices": [
+                {
+                    "text": "This is a mock response from SageMaker.",
+                    "index": 0,
+                    "logprobs": None,
+                    "finish_reason": "length",
+                }
+            ],
+            "usage": {"prompt_tokens": 1, "completion_tokens": 8, "total_tokens": 9},
+        }
+
+    mock_response.json = return_val
+    mock_response.status_code = 200
+
+    expected_payload = {
+        "inputs": "hi",
+        "parameters": {"temperature": 0.2, "max_new_tokens": 80},
+    }
+
+    with patch(
+        "litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post",
+        return_value=mock_response,
+    ) as mock_post:
+        # Act: Call the litellm.acompletion function
+        response = await litellm.acompletion(
+            model="sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
+            messages=[
+                {"role": "user", "content": "hi"},
+            ],
+            temperature=0.2,
+            max_tokens=80,
+            input_cost_per_second=0.000420,
+        )
+
+        # Print what was called on the mock
+        print("call args=", mock_post.call_args)
+
+        # Assert
+        mock_post.assert_called_once()
+        _, kwargs = mock_post.call_args
+        args_to_sagemaker = kwargs["json"]
+        print("Arguments passed to sagemaker=", args_to_sagemaker)
+        assert args_to_sagemaker == expected_payload
+        assert (
+            kwargs["url"]
+            == "https://runtime.sagemaker.us-west-2.amazonaws.com/endpoints/jumpstart-dft-hf-textgeneration1-mp-20240815-185614/invocations"
+        )
+
+
+@pytest.mark.asyncio
+async def test_completion_sagemaker_non_stream():
+    mock_response = MagicMock()
+
+    def return_val():
+        return {
+            "generated_text": "This is a mock response from SageMaker.",
+            "id": "cmpl-mockid",
+            "object": "text_completion",
+            "created": 1629800000,
+            "model": "sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
+            "choices": [
+                {
+                    "text": "This is a mock response from SageMaker.",
+                    "index": 0,
+                    "logprobs": None,
+                    "finish_reason": "length",
+                }
+            ],
+            "usage": {"prompt_tokens": 1, "completion_tokens": 8, "total_tokens": 9},
+        }
+
+    mock_response.json = return_val
+    mock_response.status_code = 200
+
+    expected_payload = {
+        "inputs": "hi",
+        "parameters": {"temperature": 0.2, "max_new_tokens": 80},
+    }
+
+    with patch(
+        "litellm.llms.custom_httpx.http_handler.HTTPHandler.post",
+        return_value=mock_response,
+    ) as mock_post:
+        # Act: Call the litellm.acompletion function
+        response = litellm.completion(
+            model="sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
+            messages=[
+                {"role": "user", "content": "hi"},
+            ],
+            temperature=0.2,
+            max_tokens=80,
+            input_cost_per_second=0.000420,
+        )
+
+        # Print what was called on the mock
+        print("call args=", mock_post.call_args)
+
+        # Assert
+        mock_post.assert_called_once()
+        _, kwargs = mock_post.call_args
+        args_to_sagemaker = kwargs["json"]
+        print("Arguments passed to sagemaker=", args_to_sagemaker)
+        assert args_to_sagemaker == expected_payload
+        assert (
+            kwargs["url"]
+            == "https://runtime.sagemaker.us-west-2.amazonaws.com/endpoints/jumpstart-dft-hf-textgeneration1-mp-20240815-185614/invocations"
+        )
+
+
+@pytest.mark.asyncio
+async def test_completion_sagemaker_non_stream_with_aws_params():
+    mock_response = MagicMock()
+
+    def return_val():
+        return {
+            "generated_text": "This is a mock response from SageMaker.",
+            "id": "cmpl-mockid",
+            "object": "text_completion",
+            "created": 1629800000,
+            "model": "sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
+            "choices": [
+                {
+                    "text": "This is a mock response from SageMaker.",
+                    "index": 0,
+                    "logprobs": None,
+                    "finish_reason": "length",
+                }
+            ],
+            "usage": {"prompt_tokens": 1, "completion_tokens": 8, "total_tokens": 9},
+        }
+
+    mock_response.json = return_val
+    mock_response.status_code = 200
+
+    expected_payload = {
+        "inputs": "hi",
+        "parameters": {"temperature": 0.2, "max_new_tokens": 80},
+    }
+
+    with patch(
+        "litellm.llms.custom_httpx.http_handler.HTTPHandler.post",
+        return_value=mock_response,
+    ) as mock_post:
+        # Act: Call the litellm.acompletion function
+        response = litellm.completion(
+            model="sagemaker/jumpstart-dft-hf-textgeneration1-mp-20240815-185614",
+            messages=[
+                {"role": "user", "content": "hi"},
+            ],
+            temperature=0.2,
+            max_tokens=80,
+            input_cost_per_second=0.000420,
+            aws_access_key_id="gm",
+            aws_secret_access_key="s",
+            aws_region_name="us-west-5",
+        )
+
+        # Print what was called on the mock
+        print("call args=", mock_post.call_args)
+
+        # Assert
+        mock_post.assert_called_once()
+        _, kwargs = mock_post.call_args
+        args_to_sagemaker = kwargs["json"]
+        print("Arguments passed to sagemaker=", args_to_sagemaker)
+        assert args_to_sagemaker == expected_payload
+        assert (
+            kwargs["url"]
+            == "https://runtime.sagemaker.us-west-5.amazonaws.com/endpoints/jumpstart-dft-hf-textgeneration1-mp-20240815-185614/invocations"
+        )
diff --git a/litellm/tests/test_streaming.py b/litellm/tests/test_streaming.py
index 025ea81200..3f42331879 100644
--- a/litellm/tests/test_streaming.py
+++ b/litellm/tests/test_streaming.py
@@ -1683,6 +1683,7 @@ def test_completion_bedrock_mistral_stream():
         pytest.fail(f"Error occurred: {e}")
+@pytest.mark.skip(reason="stopped using TokenIterator")
 def test_sagemaker_weird_response():
     """
     When the stream ends, flush any remaining holding chunks.
diff --git a/litellm/tests/test_user_api_key_auth.py b/litellm/tests/test_user_api_key_auth.py
index ad057ee572..e0595ac13c 100644
--- a/litellm/tests/test_user_api_key_auth.py
+++ b/litellm/tests/test_user_api_key_auth.py
@@ -44,7 +44,7 @@ def test_check_valid_ip(
     request = Request(client_ip)
-    assert _check_valid_ip(allowed_ips, request) == expected_result  # type: ignore
+    assert _check_valid_ip(allowed_ips, request)[0] == expected_result  # type: ignore
 # test x-forwarder for is used when user has opted in
@@ -72,7+72,7 @@ def test_check_valid_ip_sent_with_x_forwarded_for(
     request = Request(client_ip, headers={"X-Forwarded-For": client_ip})
-    assert _check_valid_ip(allowed_ips, request, use_x_forwarded_for=True) == expected_result  # type: ignore
+    assert _check_valid_ip(allowed_ips, request, use_x_forwarded_for=True)[0] == expected_result  # type: ignore
 @pytest.mark.asyncio
diff --git a/litellm/types/llms/anthropic.py b/litellm/types/llms/anthropic.py
index 36bcb6cc73..f14aa20c73 100644
--- a/litellm/types/llms/anthropic.py
+++ b/litellm/types/llms/anthropic.py
@@ -15,9 +15,10 @@ class AnthropicMessagesTool(TypedDict, total=False):
     input_schema: Required[dict]
-class AnthropicMessagesTextParam(TypedDict):
+class AnthropicMessagesTextParam(TypedDict, total=False):
     type: Literal["text"]
     text: str
+    cache_control: Optional[dict]
 class AnthropicMessagesToolUseParam(TypedDict):
@@ -54,9 +55,10 @@ class AnthropicImageParamSource(TypedDict):
     data: str
-class AnthropicMessagesImageParam(TypedDict):
+class AnthropicMessagesImageParam(TypedDict, total=False):
     type: Literal["image"]
     source: AnthropicImageParamSource
+    cache_control: Optional[dict]
 class AnthropicMessagesToolResultContent(TypedDict):
@@ -92,6 +94,12 @@ class AnthropicMetadata(TypedDict, total=False):
     user_id: str
+class AnthropicSystemMessageContent(TypedDict, total=False):
+    type: str
+    text: str
+    cache_control: Optional[dict]
+
+
 class AnthropicMessagesRequest(TypedDict, total=False):
     model: Required[str]
     messages: Required[
@@ -106,7 +114,7 @@ class AnthropicMessagesRequest(TypedDict, total=False):
     metadata: AnthropicMetadata
     stop_sequences: List[str]
     stream: bool
-    system: str
+    system: Union[str, List]
     temperature: float
     tool_choice: AnthropicMessagesToolChoice
     tools: List[AnthropicMessagesTool]
diff --git a/litellm/types/llms/openai.py b/litellm/types/llms/openai.py
index 0d67d5d602..5d2c416f9c 100644
--- a/litellm/types/llms/openai.py
+++ b/litellm/types/llms/openai.py
@@ -361,7 +361,7 @@ class ChatCompletionToolMessage(TypedDict):
 class ChatCompletionSystemMessage(TypedDict, total=False):
     role: Required[Literal["system"]]
-    content: Required[str]
+    content: Required[Union[str, List]]
     name: str
diff --git a/litellm/types/utils.py b/litellm/types/utils.py
index 5cf6270868..519b301039 100644
--- a/litellm/types/utils.py
+++ b/litellm/types/utils.py
@@ -80,7 +80,7 @@ class ModelInfo(TypedDict, total=False):
     supports_assistant_prefill: Optional[bool]
-class GenericStreamingChunk(TypedDict):
+class GenericStreamingChunk(TypedDict, total=False):
     text: Required[str]
     tool_use: Optional[ChatCompletionToolCallChunk]
     is_finished: Required[bool]
diff --git a/litellm/utils.py b/litellm/utils.py
index 49528d0f77..0875b0e0e5 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -4479,7 +4479,22 @@ def _is_non_openai_azure_model(model: str) -> bool:
             or f"mistral/{model_name}" in litellm.mistral_chat_models
         ):
             return True
-    except:
+    except Exception:
         return False
     return False
+
+
+def _is_azure_openai_model(model: str) -> bool:
+    try:
+        if "/" in model:
+            model = model.split("/", 1)[1]
+        if (
+            model in litellm.open_ai_chat_completion_models
+            or model in litellm.open_ai_text_completion_models
+            or model in litellm.open_ai_embedding_models
+        ):
+            return True
+    except Exception:
+        return False
+    return False
@@ -4613,6 +4628,14 @@ def get_llm_provider(
     elif custom_llm_provider == "azure_ai":
         api_base = api_base or get_secret("AZURE_AI_API_BASE")  # type: ignore
         dynamic_api_key = api_key or get_secret("AZURE_AI_API_KEY")
+
+        if _is_azure_openai_model(model=model):
+            verbose_logger.debug(
+                "Model={} is Azure OpenAI model. Setting custom_llm_provider='azure'.".format(
+                    model
+                )
+            )
+            custom_llm_provider = "azure"
     elif custom_llm_provider == "github":
         api_base = api_base or get_secret("GITHUB_API_BASE") or "https://models.inference.ai.azure.com"  # type: ignore
         dynamic_api_key = api_key or get_secret("GITHUB_API_KEY")
@@ -9825,11 +9848,28 @@ class CustomStreamWrapper:
                     completion_obj["tool_calls"] = [response_obj["tool_use"]]
             elif self.custom_llm_provider == "sagemaker":
-                print_verbose(f"ENTERS SAGEMAKER STREAMING for chunk {chunk}")
-                response_obj = self.handle_sagemaker_stream(chunk)
+                from litellm.types.llms.bedrock import GenericStreamingChunk
+
+                if self.received_finish_reason is not None:
+                    raise StopIteration
+                response_obj: GenericStreamingChunk = chunk
                 completion_obj["content"] = response_obj["text"]
                 if response_obj["is_finished"]:
                     self.received_finish_reason = response_obj["finish_reason"]
+
+                if (
+                    self.stream_options
+                    and self.stream_options.get("include_usage", False) is True
+                    and response_obj["usage"] is not None
+                ):
+                    model_response.usage = litellm.Usage(
+                        prompt_tokens=response_obj["usage"]["inputTokens"],
+                        completion_tokens=response_obj["usage"]["outputTokens"],
+                        total_tokens=response_obj["usage"]["totalTokens"],
+                    )
+
+                if "tool_use" in response_obj and response_obj["tool_use"] is not None:
+                    completion_obj["tool_calls"] = [response_obj["tool_use"]]
             elif self.custom_llm_provider == "petals":
                 if len(self.completion_stream) == 0:
                     if self.received_finish_reason is not None:
diff --git a/model_prices_and_context_window.json b/model_prices_and_context_window.json
index 455fe1e3c5..d30270c5c8 100644
--- a/model_prices_and_context_window.json
+++ b/model_prices_and_context_window.json
@@ -57,6 +57,18 @@
         "supports_parallel_function_calling": true,
         "supports_vision": true
     },
+    "chatgpt-4o-latest": {
+        "max_tokens": 4096,
+        "max_input_tokens": 128000,
+        "max_output_tokens": 4096,
+        "input_cost_per_token": 0.000005,
+        "output_cost_per_token": 0.000015,
+        "litellm_provider": "openai",
+        "mode": "chat",
+        "supports_function_calling": true,
+        "supports_parallel_function_calling": true,
+        "supports_vision": true
+    },
     "gpt-4o-2024-05-13": {
         "max_tokens": 4096,
         "max_input_tokens": 128000,
@@ -2062,7 +2074,8 @@
         "litellm_provider": "vertex_ai-anthropic_models",
         "mode": "chat",
         "supports_function_calling": true,
-        "supports_vision": true
+        "supports_vision": true,
+        "supports_assistant_prefill": true
     },
     "vertex_ai/claude-3-5-sonnet@20240620": {
         "max_tokens": 4096,
@@ -2073,7 +2086,8 @@
         "litellm_provider": "vertex_ai-anthropic_models",
         "mode": "chat",
         "supports_function_calling": true,
-        "supports_vision": true
+        "supports_vision": true,
+        "supports_assistant_prefill": true
     },
     "vertex_ai/claude-3-haiku@20240307": {
         "max_tokens": 4096,
@@ -2084,7 +2098,8 @@
         "litellm_provider": "vertex_ai-anthropic_models",
         "mode": "chat",
         "supports_function_calling": true,
-        "supports_vision": true
+        "supports_vision": true,
+        "supports_assistant_prefill": true
     },
     "vertex_ai/claude-3-opus@20240229": {
         "max_tokens": 4096,
@@ -2095,7 +2110,8 @@
         "litellm_provider": "vertex_ai-anthropic_models",
         "mode": "chat",
         "supports_function_calling": true,
-        "supports_vision": true
+        "supports_vision": true,
+        "supports_assistant_prefill": true
     },
     "vertex_ai/meta/llama3-405b-instruct-maas": {
         "max_tokens": 32000,
@@ -4519,6 +4535,69 @@
         "litellm_provider": "perplexity",
         "mode": "chat"
     },
+    "perplexity/llama-3.1-70b-instruct": {
+        "max_tokens": 131072,
+        "max_input_tokens": 131072,
+        "max_output_tokens": 131072,
+        "input_cost_per_token": 0.000001,
+        "output_cost_per_token": 0.000001,
+        "litellm_provider": "perplexity",
+        "mode": "chat"
+    },
+    "perplexity/llama-3.1-8b-instruct": {
+        "max_tokens": 131072,
+        "max_input_tokens": 131072,
+        "max_output_tokens": 131072,
+        "input_cost_per_token": 0.0000002,
+        "output_cost_per_token": 0.0000002,
+        "litellm_provider": "perplexity",
+        "mode": "chat"
+    },
+    "perplexity/llama-3.1-sonar-huge-128k-online": {
+        "max_tokens": 127072,
+        "max_input_tokens": 127072,
+        "max_output_tokens": 127072,
+        "input_cost_per_token": 0.000005,
+        "output_cost_per_token": 0.000005,
+        "litellm_provider": "perplexity",
+        "mode": "chat"
+    },
+    "perplexity/llama-3.1-sonar-large-128k-online": {
+        "max_tokens": 127072,
+        "max_input_tokens": 127072,
+        "max_output_tokens": 127072,
+        "input_cost_per_token": 0.000001,
+        "output_cost_per_token": 0.000001,
+        "litellm_provider": "perplexity",
+        "mode": "chat"
+    },
+    "perplexity/llama-3.1-sonar-large-128k-chat": {
+        "max_tokens": 131072,
+        "max_input_tokens": 131072,
+        "max_output_tokens": 131072,
+        "input_cost_per_token": 0.000001,
+        "output_cost_per_token": 0.000001,
+        "litellm_provider": "perplexity",
+        "mode": "chat"
+    },
+    "perplexity/llama-3.1-sonar-small-128k-chat": {
+        "max_tokens": 131072,
+        "max_input_tokens": 131072,
+        "max_output_tokens": 131072,
+        "input_cost_per_token": 0.0000002,
+        "output_cost_per_token": 0.0000002,
+        "litellm_provider": "perplexity",
+        "mode": "chat"
+    },
+    "perplexity/llama-3.1-sonar-small-128k-online": {
+        "max_tokens": 127072,
+        "max_input_tokens": 127072,
+        "max_output_tokens": 127072,
+        "input_cost_per_token": 0.0000002,
+        "output_cost_per_token": 0.0000002,
+        "litellm_provider": "perplexity",
+        "mode": "chat"
+    },
     "perplexity/pplx-7b-chat": {
         "max_tokens": 8192,
         "max_input_tokens": 8192,
diff --git a/pyproject.toml b/pyproject.toml
index ae9ba13da2..a7d069789d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "litellm"
-version = "1.43.9"
+version = "1.43.15"
 description = "Library to easily interface with LLM API providers"
 authors = ["BerriAI"]
 license = "MIT"
@@ -91,7 +91,7 @@ requires = ["poetry-core", "wheel"]
 build-backend = "poetry.core.masonry.api"
 [tool.commitizen]
-version = "1.43.9"
+version = "1.43.15"
 version_files = [
     "pyproject.toml:^version"
 ]