Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 11:14:04 +00:00)
LITELLM: Remove requests library usage (#7235)
* fix(generic_api_callback.py): remove requests lib usage
* fix(budget_manager.py): remove requests lib usage
* fix(main.py): cleanup requests lib usage
* fix(utils.py): remove requests lib usage
* fix(argilla.py): fix argilla test
* fix(athina.py): replace 'requests' lib usage with litellm module
* fix(greenscale.py): replace 'requests' lib usage with httpx
* fix: remove unused 'requests' lib import + replace usage in some places
* fix(prompt_layer.py): remove 'requests' lib usage from prompt layer
* fix(ollama_chat.py): remove 'requests' lib usage
* fix(baseten.py): replace 'requests' lib usage
* fix(codestral/): replace 'requests' lib usage
* fix(predibase/): replace 'requests' lib usage
* refactor: cleanup unused 'requests' lib imports
* fix(oobabooga.py): cleanup 'requests' lib usage
* fix(invoke_handler.py): remove unused 'requests' lib usage
* refactor: cleanup unused 'requests' lib import
* fix: fix linting errors
* refactor(ollama/): move ollama to using base llm http handler; removes 'requests' lib dep for ollama integration
* fix(ollama_chat.py): fix linting errors
* fix(ollama/completion/transformation.py): convert non-jpeg/png image to jpeg/png before passing to ollama
This commit is contained in:
parent f628290ce7
commit 03e711e3e4
46 changed files with 523 additions and 612 deletions
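Across the diff below, direct `requests` calls are replaced with litellm's shared module-level HTTP client. A minimal sketch of the pattern, assuming `litellm.module_level_client` is the synchronous httpx-backed handler that litellm exposes at module level; the endpoint, payload, and header names here are illustrative, not taken from the diff:

import litellm

def post_log_event(endpoint: str, payload: dict, headers: dict) -> dict:
    # Previously each integration called requests.post(endpoint, json=payload, headers=headers).
    response = litellm.module_level_client.post(endpoint, json=payload, headers=headers)
    if response.status_code != 200:
        raise Exception(f"logging endpoint error {response.status_code}: {response.text}")
    return response.json()

The practical effect is that every integration reuses one shared client instead of importing `requests` at each call site.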
@@ -3,7 +3,6 @@
 #### What this does ####
 # On success, logs events to Promptlayer
 import dotenv, os
-import requests
 
 from litellm.proxy._types import UserAPIKeyAuth
 from litellm.caching.caching import DualCache
@@ -17,7 +16,6 @@ import traceback
 # On success + failure, log events to Supabase
 
 import dotenv, os
-import requests
 import traceback
 import datetime, subprocess, sys
 import litellm, uuid
@@ -116,7 +114,9 @@ class GenericAPILogger:
         print_verbose(f"\nGeneric Logger - Logging payload = {data}")
 
         # make request to endpoint with payload
-        response = requests.post(self.endpoint, json=data, headers=self.headers)
+        response = litellm.module_level_client.post(
+            self.endpoint, json=data, headers=self.headers
+        )
 
         response_status = response.status_code
         response_text = response.text
@@ -13,8 +13,6 @@ import threading
 import time
 from typing import Literal, Optional, Union
 
-import requests  # type: ignore
-
 import litellm
 from litellm.utils import ModelResponse
 
@@ -58,7 +56,9 @@ class BudgetManager:
             # Load the user_dict from hosted db
             url = self.api_base + "/get_budget"
             data = {"project_name": self.project_name}
-            response = requests.post(url, headers=self.headers, json=data)
+            response = litellm.module_level_client.post(
+                url, headers=self.headers, json=data
+            )
             response = response.json()
             if response["status"] == "error":
                 self.user_dict = (
@@ -215,6 +215,8 @@ class BudgetManager:
         elif self.client_type == "hosted":
             url = self.api_base + "/set_budget"
             data = {"project_name": self.project_name, "user_dict": self.user_dict}
-            response = requests.post(url, headers=self.headers, json=data)
+            response = litellm.module_level_client.post(
+                url, headers=self.headers, json=data
+            )
             response = response.json()
             return response
@@ -15,19 +15,20 @@ from typing import Any, Dict, List, Optional, TypedDict, Union
 
 import dotenv  # type: ignore
 import httpx
-import requests  # type: ignore
 from pydantic import BaseModel  # type: ignore
 
 import litellm
 from litellm._logging import verbose_logger
 from litellm.integrations.custom_batch_logger import CustomBatchLogger
 from litellm.integrations.custom_logger import CustomLogger
+from litellm.litellm_core_utils.prompt_templates.common_utils import (
+    get_content_from_model_response,
+)
 from litellm.llms.custom_httpx.http_handler import (
     AsyncHTTPHandler,
     get_async_httpx_client,
     httpxSpecialProvider,
 )
-from litellm.litellm_core_utils.prompt_templates.common_utils import get_content_from_model_response
 from litellm.types.integrations.argilla import (
     SUPPORTED_PAYLOAD_FIELDS,
     ArgillaCredentialsObject,
@@ -223,7 +224,7 @@ class ArgillaLogger(CustomBatchLogger):
         headers = {"X-Argilla-Api-Key": argilla_api_key}
 
         try:
-            response = requests.post(
+            response = litellm.module_level_client.post(
                 url=url,
                 json=self.log_queue,
                 headers=headers,
@@ -1,5 +1,7 @@
 import datetime
 
+import litellm
+
 
 class AthinaLogger:
     def __init__(self):
@@ -27,8 +29,6 @@ class AthinaLogger:
         import json
         import traceback
 
-        import requests  # type: ignore
-
         try:
             is_stream = kwargs.get("stream", False)
             if is_stream:
@@ -81,7 +81,7 @@ class AthinaLogger:
                 if key in metadata:
                     data[key] = metadata[key]
 
-            response = requests.post(
+            response = litellm.module_level_client.post(
                 self.athina_logging_url,
                 headers=self.headers,
                 data=json.dumps(data, default=str),
@@ -8,7 +8,6 @@ import uuid
 from typing import Any
 
 import dotenv
-import requests  # type: ignore
 
 import litellm
 
@@ -2,7 +2,7 @@ import json
 import traceback
 from datetime import datetime, timezone
 
-import requests  # type: ignore
+import litellm
 
 
 class GreenscaleLogger:
@@ -54,7 +54,7 @@ class GreenscaleLogger:
         if self.greenscale_logging_url is None:
             raise Exception("Greenscale Logger Error - No logging URL found")
 
-        response = requests.post(
+        response = litellm.module_level_client.post(
             self.greenscale_logging_url,
             headers=self.headers,
             data=json.dumps(data, default=str),
@@ -4,7 +4,6 @@ import os
 import traceback
 
 import dotenv
-import requests  # type: ignore
 
 import litellm
 from litellm._logging import verbose_logger
@@ -179,7 +178,7 @@ class HeliconeLogger:
                 },
             },  # {"seconds": .., "milliseconds": ..}
         }
-        response = requests.post(url, headers=headers, json=data)
+        response = litellm.module_level_client.post(url, headers=headers, json=data)
        if response.status_code == 200:
             print_verbose("Helicone Logging - Success!")
         else:
@@ -12,7 +12,6 @@ from typing import Any, Dict, List, Optional, TypedDict, Union
 
 import dotenv  # type: ignore
 import httpx
-import requests  # type: ignore
 from pydantic import BaseModel  # type: ignore
 
 import litellm
@@ -481,7 +480,7 @@ class LangsmithLogger(CustomBatchLogger):
         langsmith_api_base = self.default_credentials["LANGSMITH_BASE_URL"]
 
         url = f"{langsmith_api_base}/runs/{run_id}"
-        response = requests.get(
+        response = litellm.module_level_client.get(
             url=url,
             headers={"x-api-key": langsmith_api_key},
         )
@@ -9,9 +9,6 @@ import uuid
 from datetime import date, datetime, timedelta
 from typing import Optional, TypedDict, Union
 
-import dotenv
-import requests  # type: ignore
-
 import litellm
 from litellm._logging import print_verbose, verbose_logger
 from litellm.integrations.custom_logger import CustomLogger
@@ -11,9 +11,6 @@ import traceback
 import uuid
 from typing import List, Optional, Union
 
-import dotenv
-import requests  # type: ignore
-
 import litellm
 from litellm._logging import print_verbose, verbose_logger
 from litellm.types.integrations.prometheus import LATENCY_BUCKETS
@@ -3,10 +3,10 @@
 import os
 import traceback
 
-import dotenv
-import requests  # type: ignore
 from pydantic import BaseModel
 
+import litellm
+
 
 class PromptLayerLogger:
     # Class variables or attributes
@@ -47,7 +47,7 @@ class PromptLayerLogger:
         if isinstance(response_obj, BaseModel):
             response_obj = response_obj.model_dump()
 
-        request_response = requests.post(
+        request_response = litellm.module_level_client.post(
             "https://api.promptlayer.com/rest/track-request",
             json={
                 "function_name": "openai.ChatCompletion.create",
@@ -74,7 +74,7 @@ class PromptLayerLogger:
 
         if "request_id" in response_json:
             if metadata:
-                response = requests.post(
+                response = litellm.module_level_client.post(
                     "https://api.promptlayer.com/rest/track-metadata",
                     json={
                         "request_id": response_json["request_id"],
@@ -8,7 +8,6 @@ import sys
 import traceback
 
 import dotenv
-import requests  # type: ignore
 
 import litellm
 
@@ -177,8 +177,6 @@ import os
 import traceback
 from datetime import datetime
 
-import requests
-
 
 class WeightsBiasesLogger:
     # Class variables or attributes
@@ -570,40 +570,6 @@ class CustomStreamWrapper:
             )
             return ""
 
-    def handle_ollama_stream(self, chunk):
-        try:
-            if isinstance(chunk, dict):
-                json_chunk = chunk
-            else:
-                json_chunk = json.loads(chunk)
-            if "error" in json_chunk:
-                raise Exception(f"Ollama Error - {json_chunk}")
-
-            text = ""
-            is_finished = False
-            finish_reason = None
-            if json_chunk["done"] is True:
-                text = ""
-                is_finished = True
-                finish_reason = "stop"
-                return {
-                    "text": text,
-                    "is_finished": is_finished,
-                    "finish_reason": finish_reason,
-                }
-            elif json_chunk["response"]:
-                print_verbose(f"delta content: {json_chunk}")
-                text = json_chunk["response"]
-                return {
-                    "text": text,
-                    "is_finished": is_finished,
-                    "finish_reason": finish_reason,
-                }
-            else:
-                raise Exception(f"Ollama Error - {json_chunk}")
-        except Exception as e:
-            raise e
-
     def handle_ollama_chat_stream(self, chunk):
         # for ollama_chat/ provider
         try:
@@ -1111,12 +1077,6 @@ class CustomStreamWrapper:
                     new_chunk = self.completion_stream[:chunk_size]
                     completion_obj["content"] = new_chunk
                     self.completion_stream = self.completion_stream[chunk_size:]
-                elif self.custom_llm_provider == "ollama":
-                    response_obj = self.handle_ollama_stream(chunk)
-                    completion_obj["content"] = response_obj["text"]
-                    print_verbose(f"completion obj content: {completion_obj['content']}")
-                    if response_obj["is_finished"]:
-                        self.received_finish_reason = response_obj["finish_reason"]
                 elif self.custom_llm_provider == "ollama_chat":
                     response_obj = self.handle_ollama_chat_stream(chunk)
                     completion_obj["content"] = response_obj["text"]
@@ -13,7 +13,6 @@ from functools import partial
 from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
 
 import httpx  # type: ignore
-import requests  # type: ignore
 from openai.types.chat.chat_completion_chunk import Choice as OpenAIStreamingChoice
 
 import litellm
@@ -16,7 +16,6 @@ from typing import (
 )
 
 import httpx
-import requests
 
 import litellm
 from litellm.constants import RESPONSE_FORMAT_TOOL_NAME
@@ -4,11 +4,14 @@ import uuid
 from typing import Any, Callable, Optional, Union
 
 import httpx
-import requests
 from openai import AsyncAzureOpenAI, AzureOpenAI
 
 import litellm
 from litellm import OpenAIConfig
+from litellm.litellm_core_utils.prompt_templates.factory import (
+    custom_prompt,
+    prompt_factory,
+)
 from litellm.utils import (
     Choices,
     CustomStreamWrapper,
@@ -22,7 +25,6 @@ from litellm.utils import (
 from ...base import BaseLLM
 from ...openai.completion.handler import OpenAITextCompletion
 from ...openai.completion.transformation import OpenAITextCompletionConfig
-from litellm.litellm_core_utils.prompt_templates.factory import custom_prompt, prompt_factory
 from ..common_utils import AzureOpenAIError
 
 openai_text_completion_config = OpenAITextCompletionConfig()
@@ -4,9 +4,8 @@ import time
 from enum import Enum
 from typing import Callable
 
-import requests  # type: ignore
-from litellm.utils import ModelResponse, Usage
+import litellm
+from litellm.types.utils import ModelResponse, Usage
 
 
 class BasetenError(Exception):
@@ -71,7 +70,7 @@ def completion(
         additional_args={"complete_input_dict": data},
     )
     ## COMPLETION CALL
-    response = requests.post(
+    response = litellm.module_level_client.post(
         completion_url_fragment_1 + model + completion_url_fragment_2,
         headers=headers,
         data=json.dumps(data),
@@ -10,8 +10,6 @@ from litellm._logging import verbose_logger
 from litellm.caching.caching import DualCache, InMemoryCache
 from litellm.secret_managers.main import get_secret, get_secret_str
 
-from litellm.llms.base import BaseLLM
-
 if TYPE_CHECKING:
     from botocore.credentials import Credentials
 else:
@@ -37,7 +35,7 @@ class AwsAuthError(Exception):
         )  # Call the base class constructor with the parameters it needs
 
 
-class BaseAWSLLM(BaseLLM):
+class BaseAWSLLM:
     def __init__(self) -> None:
         self.iam_cache = DualCache()
         super().__init__()
@@ -25,7 +25,6 @@ from typing import (
 )
 
 import httpx  # type: ignore
-import requests  # type: ignore
 
 import litellm
 from litellm import verbose_logger
@@ -316,7 +315,7 @@ class BedrockLLM(BaseAWSLLM):
     def process_response(  # noqa: PLR0915
         self,
         model: str,
-        response: Union[requests.Response, httpx.Response],
+        response: httpx.Response,
         model_response: ModelResponse,
         stream: bool,
         logging_obj: Logging,
@@ -1041,9 +1040,6 @@ class BedrockLLM(BaseAWSLLM):
             )
         return streaming_response
 
-    def embedding(self, *args, **kwargs):
-        return super().embedding(*args, **kwargs)
-
 
 def get_response_stream_shape():
     global _response_stream_shape_cache
@@ -12,7 +12,6 @@ from functools import partial
 from typing import Callable, List, Literal, Optional, Union
 
 import httpx  # type: ignore
-import requests  # type: ignore
 
 import litellm
 from litellm import verbose_logger
@@ -22,7 +21,6 @@ from litellm.litellm_core_utils.prompt_templates.factory import (
     custom_prompt,
     prompt_factory,
 )
-from litellm.llms.base import BaseLLM
 from litellm.llms.custom_httpx.http_handler import (
     AsyncHTTPHandler,
     get_async_httpx_client,
@@ -95,7 +93,7 @@ async def make_call(
     return completion_stream
 
 
-class CodestralTextCompletion(BaseLLM):
+class CodestralTextCompletion:
     def __init__(self) -> None:
         super().__init__()
 
@@ -139,7 +137,7 @@ class CodestralTextCompletion(BaseLLM):
     def process_text_completion_response(
         self,
         model: str,
-        response: Union[requests.Response, httpx.Response],
+        response: httpx.Response,
         model_response: TextCompletionResponse,
         stream: bool,
         logging_obj: LiteLLMLogging,
@@ -317,7 +315,7 @@ class CodestralTextCompletion(BaseLLM):
 
         ### SYNC STREAMING
         if stream is True:
-            response = requests.post(
+            response = litellm.module_level_client.post(
                 completion_url,
                 headers=headers,
                 data=json.dumps(data),
@@ -333,7 +331,7 @@ class CodestralTextCompletion(BaseLLM):
         ### SYNC COMPLETION
         else:
 
-            response = requests.post(
+            response = litellm.module_level_client.post(
                 url=completion_url,
                 headers=headers,
                 data=json.dumps(data),
@@ -6,8 +6,7 @@ import types
 from enum import Enum
 from typing import Any, Callable, Optional, Union
 
-import httpx  # type: ignore
-import requests  # type: ignore
+import httpx
 
 import litellm
 from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
@@ -491,7 +491,7 @@ class HTTPHandler:
         self,
         url: str,
         data: Optional[Union[dict, str]] = None,
-        json: Optional[Union[dict, str]] = None,
+        json: Optional[Union[dict, str, List]] = None,
         params: Optional[dict] = None,
         headers: Optional[dict] = None,
         stream: bool = False,
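The widened `json` parameter on `HTTPHandler.post` above means a JSON array body is accepted alongside dict and string bodies (the Argilla logger, for example, posts its `self.log_queue` list directly). A small illustrative call, with a made-up URL and records:

from litellm.llms.custom_httpx.http_handler import HTTPHandler

client = HTTPHandler()
# A list payload is now a valid `json=` argument, not just a dict or a string.
resp = client.post("https://example.com/api/records", json=[{"id": 1}, {"id": 2}])
print(resp.status_code)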
@@ -29,8 +29,7 @@ from typing import (
     Union,
 )
 
-import httpx  # type: ignore
-import requests  # type: ignore
+import httpx
 
 import litellm
 from litellm.litellm_core_utils.core_helpers import map_finish_reason
@@ -46,6 +45,7 @@ from litellm.utils import (
 
 from .base import BaseLLM
 
+
 class CustomLLMError(Exception):  # use this for all your exceptions
     def __init__(
         self,
@@ -6,7 +6,6 @@ from enum import Enum
 from typing import Callable, Optional
 
 import httpx  # type: ignore
-import requests  # type: ignore
 
 import litellm
 from litellm.utils import Choices, Message, ModelResponse, Usage
@@ -240,7 +239,7 @@ def completion(
         additional_args={"complete_input_dict": data},
     )
     ## COMPLETION CALL
-    response = requests.post(
+    response = litellm.module_level_client.post(
         completion_url,
         headers=headers,
         data=json.dumps(data),
@@ -20,7 +20,6 @@ from typing import (
 )
 
 import httpx
-import requests
 
 import litellm
 from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
@@ -10,3 +10,36 @@ class OllamaError(BaseLLMException):
         self, status_code: int, message: str, headers: Union[dict, httpx.Headers]
     ):
         super().__init__(status_code=status_code, message=message, headers=headers)
+
+
+def _convert_image(image):
+    """
+    Convert image to base64 encoded image if not already in base64 format
+
+    If image is already in base64 format AND is a jpeg/png, return it
+
+    If image is not JPEG/PNG, convert it to JPEG base64 format
+    """
+    import base64
+    import io
+
+    try:
+        from PIL import Image
+    except Exception:
+        raise Exception(
+            "ollama image conversion failed please run `pip install Pillow`"
+        )
+
+    orig = image
+    if image.startswith("data:"):
+        image = image.split(",")[-1]
+    try:
+        image_data = Image.open(io.BytesIO(base64.b64decode(image)))
+        if image_data.format in ["JPEG", "PNG"]:
+            return image
+    except Exception:
+        return orig
+    jpeg_image = io.BytesIO()
+    image_data.convert("RGB").save(jpeg_image, "JPEG")
+    jpeg_image.seek(0)
+    return base64.b64encode(jpeg_image.getvalue()).decode("utf-8")
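A quick usage sketch for the `_convert_image` helper added above; it needs Pillow, the import path assumes the helper lives in litellm's ollama `common_utils` module, and the in-memory PNG is only for illustration:

import base64
import io

from PIL import Image

from litellm.llms.ollama.common_utils import _convert_image  # assumed path

# Build a tiny PNG in memory and base64-encode it.
buf = io.BytesIO()
Image.new("RGB", (1, 1)).save(buf, format="PNG")
png_b64 = base64.b64encode(buf.getvalue()).decode("utf-8")

# Already base64 PNG, so the helper returns it unchanged; a data URI prefix is stripped first.
assert _convert_image(png_b64) == png_b64
assert _convert_image("data:image/png;base64," + png_b64) == png_b64
# A non-JPEG/PNG payload (e.g. WEBP) would instead be re-encoded to JPEG base64.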
@@ -1,3 +1,9 @@
+"""
+Ollama /chat/completion calls handled in llm_http_handler.py
+
+[TODO]: migrate embeddings to a base handler as well.
+"""
+
 import asyncio
 import json
 import time
@@ -8,10 +14,6 @@ from copy import deepcopy
 from itertools import chain
 from typing import Any, Dict, List, Optional
 
-import aiohttp
-import httpx  # type: ignore
-import requests  # type: ignore
-
 import litellm
 from litellm import verbose_logger
 from litellm.litellm_core_utils.prompt_templates.factory import (
@@ -31,370 +33,8 @@ from litellm.types.utils import (
 from ..common_utils import OllamaError
 from .transformation import OllamaConfig
 
 # ollama wants plain base64 jpeg/png files as images. strip any leading dataURI
 # and convert to jpeg if necessary.
-def _convert_image(image):
-    ...
-
-
-# ollama implementation
-def get_ollama_response(
-    model_response: ModelResponse,
-    model: str,
-    prompt: str,
-    optional_params: dict,
-    logging_obj: Any,
-    encoding: Any,
-    acompletion: bool = False,
-    api_base="http://localhost:11434",
-):
-    ...
-
-
-def ollama_completion_stream(url, data, logging_obj):
-    ...
-
-
-async def ollama_async_streaming(url, data, model_response, encoding, logging_obj):
-    ...
-
-
-async def ollama_acompletion(
-    url, data, model_response: litellm.ModelResponse, encoding, logging_obj
-):
-    ...
 
 
 async def ollama_aembeddings(
     api_base: str,
@@ -432,28 +72,7 @@ async def ollama_aembeddings(
     total_input_tokens = 0
     output_data = []
 
-    timeout = aiohttp.ClientTimeout(total=litellm.request_timeout)  # 10 minutes
-    async with aiohttp.ClientSession(timeout=timeout) as session:
-        ## LOGGING
-        logging_obj.pre_call(
-            input=None,
-            api_key=None,
-            additional_args={
-                "api_base": url,
-                "complete_input_dict": data,
-                "headers": {},
-            },
-        )
-
-        response = await session.post(url, json=data)
-
-        if response.status != 200:
-            text = await response.text()
-            raise OllamaError(
-                status_code=response.status,
-                message=text,
-                headers=dict(response.headers),
-            )
-
+    response = await litellm.module_level_aclient.post(url=url, json=data)
 
     response_json = await response.json()
 
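The async counterpart used above is `litellm.module_level_aclient`. A minimal sketch of calling an Ollama-style endpoint through it, assuming the shared async client returns an httpx response; the default base URL matches the removed handler, but the endpoint path and payload are illustrative:

import asyncio

import litellm

async def ollama_embed(api_base: str = "http://localhost:11434"):
    # Shared async client, replacing the per-call aiohttp.ClientSession.
    response = await litellm.module_level_aclient.post(
        url=f"{api_base}/api/embeddings",
        json={"model": "nomic-embed-text", "prompt": "hello world"},
    )
    return response.json()  # assumes an httpx.Response, whose json() is synchronous

# asyncio.run(ollama_embed())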
@@ -1,20 +1,34 @@
+import json
+import time
 import types
-from typing import TYPE_CHECKING, Any, List, Optional, Union
+import uuid
+from typing import TYPE_CHECKING, Any, AsyncIterator, Iterator, List, Optional, Union
 
 from httpx._models import Headers, Response
 
 import litellm
+from litellm.litellm_core_utils.prompt_templates.factory import (
+    convert_to_ollama_image,
+    custom_prompt,
+    ollama_pt,
+)
+from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator
 from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
 from litellm.secret_managers.main import get_secret_str
-from litellm.types.llms.openai import AllMessageValues
+from litellm.types.llms.openai import (
+    AllMessageValues,
+    ChatCompletionToolCallChunk,
+    ChatCompletionUsageBlock,
+)
 from litellm.types.utils import (
+    GenericStreamingChunk,
     ModelInfo,
     ModelResponse,
     ProviderField,
     StreamingChoices,
 )
 
-from ..common_utils import OllamaError
+from ..common_utils import OllamaError, _convert_image
 
 if TYPE_CHECKING:
     from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
@@ -247,7 +261,47 @@ class OllamaConfig(BaseConfig):
         api_key: Optional[str] = None,
         json_mode: Optional[bool] = None,
     ) -> ModelResponse:
-        raise NotImplementedError("transformation currently done in handler.py")
+        response_json = raw_response.json()
+        ## RESPONSE OBJECT
+        model_response.choices[0].finish_reason = "stop"
+        if request_data.get("format", "") == "json":
+            function_call = json.loads(response_json["response"])
+            message = litellm.Message(
+                content=None,
+                tool_calls=[
+                    {
+                        "id": f"call_{str(uuid.uuid4())}",
+                        "function": {
+                            "name": function_call["name"],
+                            "arguments": json.dumps(function_call["arguments"]),
+                        },
+                        "type": "function",
+                    }
+                ],
+            )
+            model_response.choices[0].message = message  # type: ignore
+            model_response.choices[0].finish_reason = "tool_calls"
+        else:
+            model_response.choices[0].message.content = response_json["response"]  # type: ignore
+        model_response.created = int(time.time())
+        model_response.model = "ollama/" + model
+        _prompt = request_data.get("prompt", "")
+        prompt_tokens = response_json.get(
+            "prompt_eval_count", len(encoding.encode(_prompt, disallowed_special=()))  # type: ignore
+        )
+        completion_tokens = response_json.get(
+            "eval_count", len(response_json.get("message", dict()).get("content", ""))
+        )
+        setattr(
+            model_response,
+            "usage",
+            litellm.Usage(
+                prompt_tokens=prompt_tokens,
+                completion_tokens=completion_tokens,
+                total_tokens=prompt_tokens + completion_tokens,
+            ),
+        )
+        return model_response
 
     def transform_request(
         self,
@@ -257,7 +311,46 @@ class OllamaConfig(BaseConfig):
         litellm_params: dict,
         headers: dict,
     ) -> dict:
-        raise NotImplementedError("transformation currently done in handler.py")
+        custom_prompt_dict = (
+            litellm_params.get("custom_prompt_dict") or litellm.custom_prompt_dict
+        )
+        if model in custom_prompt_dict:
+            # check if the model has a registered custom prompt
+            model_prompt_details = custom_prompt_dict[model]
+            ollama_prompt = custom_prompt(
+                role_dict=model_prompt_details["roles"],
+                initial_prompt_value=model_prompt_details["initial_prompt_value"],
+                final_prompt_value=model_prompt_details["final_prompt_value"],
+                messages=messages,
+            )
+        else:
+            modified_prompt = ollama_pt(model=model, messages=messages)
+            if isinstance(modified_prompt, dict):
+                ollama_prompt, images = (
+                    modified_prompt["prompt"],
+                    modified_prompt["images"],
+                )
+                optional_params["images"] = images
+            else:
+                ollama_prompt = modified_prompt
+        stream = optional_params.pop("stream", False)
+        format = optional_params.pop("format", None)
+        images = optional_params.pop("images", None)
+        data = {
+            "model": model,
+            "prompt": ollama_prompt,
+            "options": optional_params,
+            "stream": stream,
+        }
+
+        if format is not None:
+            data["format"] = format
+        if images is not None:
+            data["images"] = [
+                _convert_image(convert_to_ollama_image(image)) for image in images
+            ]
+
+        return data
 
     def validate_environment(
         self,
@@ -267,4 +360,77 @@ class OllamaConfig(BaseConfig):
         optional_params: dict,
         api_key: Optional[str] = None,
     ) -> dict:
-        raise NotImplementedError("validation currently done in handler.py")
+        return headers
+
+    def get_complete_url(self, api_base: str, model: str) -> str:
+        """
+        OPTIONAL
+
+        Get the complete url for the request
+
+        Some providers need `model` in `api_base`
+        """
+        if api_base.endswith("/api/generate"):
+            url = api_base
+        else:
+            url = f"{api_base}/api/generate"
+
+        return url
+
+    def get_model_response_iterator(
+        self,
+        streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],
+        sync_stream: bool,
+        json_mode: Optional[bool] = False,
+    ):
+        return OllamaTextCompletionResponseIterator(
+            streaming_response=streaming_response,
+            sync_stream=sync_stream,
+            json_mode=json_mode,
+        )
+
+
+class OllamaTextCompletionResponseIterator(BaseModelResponseIterator):
+    def _handle_string_chunk(self, str_line: str) -> GenericStreamingChunk:
+        return self.chunk_parser(json.loads(str_line))
+
+    def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
+        try:
+            if "error" in chunk:
+                raise Exception(f"Ollama Error - {chunk}")
+
+            text = ""
+            is_finished = False
+            finish_reason = None
+            if chunk["done"] is True:
+                text = ""
+                is_finished = True
+                finish_reason = "stop"
+                prompt_eval_count: Optional[int] = chunk.get("prompt_eval_count", None)
+                eval_count: Optional[int] = chunk.get("eval_count", None)
+
+                usage: Optional[ChatCompletionUsageBlock] = None
+                if prompt_eval_count is not None and eval_count is not None:
+                    usage = ChatCompletionUsageBlock(
+                        prompt_tokens=prompt_eval_count,
+                        completion_tokens=eval_count,
+                        total_tokens=prompt_eval_count + eval_count,
+                    )
+                return GenericStreamingChunk(
+                    text=text,
+                    is_finished=is_finished,
+                    finish_reason=finish_reason,
+                    usage=usage,
+                )
+            elif chunk["response"]:
+                text = chunk["response"]
+                return GenericStreamingChunk(
+                    text=text,
+                    is_finished=is_finished,
+                    finish_reason="stop",
+                    usage=None,
+                )
+            else:
+                raise Exception(f"Unable to parse ollama chunk - {chunk}")
+        except Exception as e:
+            raise e
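The iterator added above maps raw Ollama `/api/generate` stream lines onto litellm's generic chunk format. A rough sketch of that mapping with illustrative JSON lines (the field names are the ones `chunk_parser` reads):

import json

# Mid-stream line: carries a text delta, no token counts yet.
line = '{"model": "llama3", "response": "Hello", "done": false}'
chunk = json.loads(line)
# chunk_parser(chunk) -> GenericStreamingChunk(text="Hello", is_finished=False, usage=None)

# Final line: empty delta, done=true, plus counts that become the usage block.
final = '{"model": "llama3", "response": "", "done": true, "prompt_eval_count": 10, "eval_count": 25}'
# chunk_parser(json.loads(final)) -> is_finished=True, finish_reason="stop",
#   usage=ChatCompletionUsageBlock(prompt_tokens=10, completion_tokens=25, total_tokens=35)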
@@ -8,7 +8,6 @@ from typing import Any, List, Optional
 
 import aiohttp
 import httpx
-import requests
 from pydantic import BaseModel
 
 import litellm
@@ -297,13 +296,14 @@ def get_ollama_response(  # noqa: PLR0915
         url=url, api_key=api_key, data=data, logging_obj=logging_obj
     )
 
-    _request = {
-        "url": f"{url}",
-        "json": data,
-    }
+    headers: Optional[dict] = None
     if api_key is not None:
-        _request["headers"] = {"Authorization": "Bearer {}".format(api_key)}
-    response = requests.post(**_request)  # type: ignore
+        headers = {"Authorization": "Bearer {}".format(api_key)}
+    response = litellm.module_level_client.post(
+        url=url,
+        json=data,
+        headers=headers,
+    )
     if response.status_code != 200:
         raise OllamaError(status_code=response.status_code, message=response.text)
 
@@ -4,12 +4,14 @@ import time
 from enum import Enum
 from typing import Any, Callable, Optional
 
-import requests  # type: ignore
+import litellm
+from litellm.litellm_core_utils.prompt_templates.factory import (
+    custom_prompt,
+    prompt_factory,
+)
 from litellm.llms.custom_httpx.http_handler import HTTPHandler, _get_httpx_client
 from litellm.utils import EmbeddingResponse, ModelResponse, Usage
 
-from litellm.litellm_core_utils.prompt_templates.factory import custom_prompt, prompt_factory
 from ..common_utils import OobaboogaError
 from .transformation import OobaboogaConfig
 
@@ -129,9 +131,9 @@ def embedding(
         messages=[],
         optional_params=optional_params,
     )
-    response = requests.post(embeddings_url, headers=headers, json=data)
-    if not response.ok:
-        raise OobaboogaError(message=response.text, status_code=response.status_code)
+    response = litellm.module_level_client.post(
+        embeddings_url, headers=headers, json=data
+    )
     completion_response = response.json()
 
     # Check for errors in response
@@ -13,8 +13,7 @@ from enum import Enum
 from functools import partial
 from typing import Any, Callable, List, Literal, Optional, Tuple, Union
 
-import httpx  # type: ignore
-import requests  # type: ignore
+import httpx
 
 import litellm
 from litellm import LlmProviders
@@ -12,7 +12,6 @@ from functools import partial
 from typing import Callable, List, Literal, Optional, Union

 import httpx  # type: ignore
-import requests  # type: ignore

 import litellm
 import litellm.litellm_core_utils
@@ -63,7 +62,7 @@ async def make_call(
     return completion_stream


-class PredibaseChatCompletion(BaseLLM):
+class PredibaseChatCompletion:
    def __init__(self) -> None:
        super().__init__()

@@ -90,7 +89,7 @@ class PredibaseChatCompletion(BaseLLM):
    def process_response(  # noqa: PLR0915
        self,
        model: str,
-        response: Union[requests.Response, httpx.Response],
+        response: httpx.Response,
        model_response: ModelResponse,
        stream: bool,
        logging_obj: LiteLLMLoggingBaseClass,
@@ -347,7 +346,7 @@ class PredibaseChatCompletion(BaseLLM):

        ### SYNC STREAMING
        if stream is True:
-            response = requests.post(
+            response = litellm.module_level_client.post(
                completion_url,
                headers=headers,
                data=json.dumps(data),

@@ -363,7 +362,7 @@ class PredibaseChatCompletion(BaseLLM):
            return _response
        ### SYNC COMPLETION
        else:
-            response = requests.post(
+            response = litellm.module_level_client.post(
                url=completion_url,
                headers=headers,
                data=json.dumps(data),
@@ -10,12 +10,16 @@ from enum import Enum
 from functools import partial
 from typing import Any, AsyncIterator, Callable, Dict, Iterator, List, Optional, Union

-import httpx  # type: ignore
-import requests  # type: ignore
+import httpx

 import litellm
 from litellm._logging import verbose_logger
 from litellm.litellm_core_utils.asyncify import asyncify
+from litellm.litellm_core_utils.prompt_templates.factory import (
+    custom_prompt,
+    prompt_factory,
+)
+from litellm.llms.bedrock.base_aws_llm import BaseAWSLLM
 from litellm.llms.custom_httpx.http_handler import (
     AsyncHTTPHandler,
     HTTPHandler,
@@ -31,8 +35,6 @@ from litellm.utils import (
     get_secret,
 )

-from litellm.llms.bedrock.base_aws_llm import BaseAWSLLM
-from litellm.litellm_core_utils.prompt_templates.factory import custom_prompt, prompt_factory
 from ..common_utils import AWSEventStreamDecoder, SagemakerError
 from .transformation import SagemakerConfig
@@ -24,23 +24,22 @@ from typing import (
 )

 import httpx  # type: ignore
-import requests  # type: ignore

 import litellm
 import litellm.litellm_core_utils
 import litellm.litellm_core_utils.litellm_logging
 from litellm import verbose_logger
 from litellm.litellm_core_utils.core_helpers import map_finish_reason
+from litellm.litellm_core_utils.prompt_templates.factory import (
+    convert_generic_image_chunk_to_openai_image_obj,
+    convert_to_anthropic_image_obj,
+)
 from litellm.llms.base_llm.transformation import BaseConfig, BaseLLMException
 from litellm.llms.custom_httpx.http_handler import (
     AsyncHTTPHandler,
     HTTPHandler,
     get_async_httpx_client,
 )
-from litellm.litellm_core_utils.prompt_templates.factory import (
-    convert_generic_image_chunk_to_openai_image_obj,
-    convert_to_anthropic_image_obj,
-)
 from litellm.types.llms.openai import (
     AllMessageValues,
     ChatCompletionResponseMessage,
@@ -7,19 +7,18 @@ import uuid
 from enum import Enum
 from typing import Any, Callable, List, Literal, Optional, Union, cast

-import httpx  # type: ignore
-import requests  # type: ignore
+import httpx
 from pydantic import BaseModel

 import litellm
 from litellm._logging import verbose_logger
 from litellm.litellm_core_utils.core_helpers import map_finish_reason
-from litellm.llms.custom_httpx.http_handler import _DEFAULT_TTL_FOR_HTTPX_CLIENTS
 from litellm.litellm_core_utils.prompt_templates.factory import (
     convert_to_anthropic_image_obj,
     convert_to_gemini_tool_call_invoke,
     convert_to_gemini_tool_call_result,
 )
+from litellm.llms.custom_httpx.http_handler import _DEFAULT_TTL_FOR_HTTPX_CLIENTS
 from litellm.types.files import (
     get_file_mime_type_for_file_type,
     get_file_type_from_extension,
@@ -9,11 +9,19 @@ import uuid
 from enum import Enum
 from typing import Any, Callable, List, Optional, Tuple, Union

-import httpx  # type: ignore
-import requests  # type: ignore
+import httpx

 import litellm
 from litellm.litellm_core_utils.core_helpers import map_finish_reason
+from litellm.litellm_core_utils.prompt_templates.factory import (
+    construct_tool_use_system_prompt,
+    contains_tag,
+    custom_prompt,
+    extract_between_tags,
+    parse_xml_params,
+    prompt_factory,
+    response_schema_prompt,
+)
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
 from litellm.types.llms.openai import (
     AllMessageValues,
@@ -24,15 +32,6 @@ from litellm.types.utils import ResponseFormatChunk
 from litellm.utils import CustomStreamWrapper, ModelResponse, Usage

 from ....anthropic.chat.transformation import AnthropicConfig
-from litellm.litellm_core_utils.prompt_templates.factory import (
-    construct_tool_use_system_prompt,
-    contains_tag,
-    custom_prompt,
-    extract_between_tags,
-    parse_xml_params,
-    prompt_factory,
-    response_schema_prompt,
-)


 class VertexAIError(Exception):
@@ -5,12 +5,13 @@ from enum import Enum
 from typing import Any, Callable

 import httpx
-import requests  # type: ignore

+from litellm.litellm_core_utils.prompt_templates.factory import (
+    custom_prompt,
+    prompt_factory,
+)
 from litellm.utils import ModelResponse, Usage

-from litellm.litellm_core_utils.prompt_templates.factory import custom_prompt, prompt_factory
-
 llm = None
@@ -2645,45 +2645,24 @@ def completion(  # type: ignore # noqa: PLR0915
             or get_secret("OLLAMA_API_BASE")
             or "http://localhost:11434"
         )
-        custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict
-        if model in custom_prompt_dict:
-            # check if the model has a registered custom prompt
-            model_prompt_details = custom_prompt_dict[model]
-            ollama_prompt = custom_prompt(
-                role_dict=model_prompt_details["roles"],
-                initial_prompt_value=model_prompt_details["initial_prompt_value"],
-                final_prompt_value=model_prompt_details["final_prompt_value"],
-                messages=messages,
-            )
-        else:
-            modified_prompt = ollama_pt(model=model, messages=messages)
-            if isinstance(modified_prompt, dict):
-                # for multimode models - ollama/llava prompt_factory returns a dict {
-                #     "prompt": prompt,
-                #     "images": images
-                # }
-                ollama_prompt, images = (
-                    modified_prompt["prompt"],
-                    modified_prompt["images"],
-                )
-                optional_params["images"] = images
-            else:
-                ollama_prompt = modified_prompt
-        ## LOGGING
-        generator = ollama.get_ollama_response(
-            api_base=api_base,
+        response = base_llm_http_handler.completion(
             model=model,
-            prompt=ollama_prompt,
-            optional_params=optional_params,
-            logging_obj=logging,
+            stream=stream,
+            messages=messages,
             acompletion=acompletion,
+            api_base=api_base,
             model_response=model_response,
+            optional_params=optional_params,
+            litellm_params=litellm_params,
+            custom_llm_provider="ollama",
+            timeout=timeout,
+            headers=headers,
             encoding=encoding,
+            api_key=api_key,
+            logging_obj=logging,  # model call logging done inside the class as we make need to modify I/O to fit aleph alpha's requirements
+            client=client,
         )
-        if acompletion is True or optional_params.get("stream", False) is True:
-            return generator
-
-        response = generator
     elif custom_llm_provider == "ollama_chat":
         api_base = (
             litellm.api_base
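After this hunk, the ollama branch of completion() no longer builds the prompt and calls ollama.get_ollama_response inline; prompt construction and the HTTP round trip move behind base_llm_http_handler.completion(..., custom_llm_provider="ollama", ...). The public API is unchanged, so a call like the following still exercises the new path (model name is illustrative, and it assumes a local Ollama server on the default port shown above):

import litellm

# same call signature as before the refactor; routing now goes through the
# shared base HTTP handler instead of a provider-specific requests call
response = litellm.completion(
    model="ollama/llama2",  # illustrative local model name
    messages=[{"role": "user", "content": "Hi"}],
    api_base="http://localhost:11434",  # default Ollama endpoint from the hunk above
)
print(response.choices[0].message.content)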
@@ -2833,8 +2812,6 @@ def completion(  # type: ignore # noqa: PLR0915
             return response
         response = model_response
     elif custom_llm_provider == "custom":
-        import requests
-
         url = litellm.api_base or api_base or ""
         if url is None or url == "":
             raise ValueError(
@@ -2843,7 +2820,7 @@ def completion(  # type: ignore # noqa: PLR0915

         """
         assume input to custom LLM api bases follow this format:
-        resp = requests.post(
+        resp = litellm.module_level_client.post(
             api_base,
             json={
                 'model': 'meta-llama/Llama-2-13b-hf', # model name
@@ -2859,7 +2836,7 @@ def completion(  # type: ignore # noqa: PLR0915

         """
         prompt = " ".join([message["content"] for message in messages])  # type: ignore
-        resp = requests.post(
+        resp = litellm.module_level_client.post(
             url,
             json={
                 "model": model,
@@ -2871,7 +2848,6 @@ def completion(  # type: ignore # noqa: PLR0915
                     "top_k": kwargs.get("top_k", 40),
                 },
             },
-            verify=litellm.ssl_verify,
         )
         response_json = resp.json()
         """
@@ -303,7 +303,7 @@ def run_server(  # noqa: PLR0915
         return
     if model and "ollama" in model and api_base is None:
         run_ollama_serve()
-    import requests
+    import httpx

     if test_async is True:
         import concurrent
@@ -319,7 +319,7 @@ def run_server(  # noqa: PLR0915
             ],
         }

-        response = requests.post("http://0.0.0.0:4000/queue/request", json=data)
+        response = httpx.post("http://0.0.0.0:4000/queue/request", json=data)

         response = response.json()
|
@ -327,7 +327,7 @@ def run_server( # noqa: PLR0915
|
||||||
try:
|
try:
|
||||||
url = response["url"]
|
url = response["url"]
|
||||||
polling_url = f"{api_base}{url}"
|
polling_url = f"{api_base}{url}"
|
||||||
polling_response = requests.get(polling_url)
|
polling_response = httpx.get(polling_url)
|
||||||
polling_response = polling_response.json()
|
polling_response = polling_response.json()
|
||||||
print("\n RESPONSE FROM POLLING JOB", polling_response) # noqa
|
print("\n RESPONSE FROM POLLING JOB", polling_response) # noqa
|
||||||
status = polling_response["status"]
|
status = polling_response["status"]
|
||||||
|
@ -378,7 +378,7 @@ def run_server( # noqa: PLR0915
|
||||||
if health is not False:
|
if health is not False:
|
||||||
|
|
||||||
print("\nLiteLLM: Health Testing models in config") # noqa
|
print("\nLiteLLM: Health Testing models in config") # noqa
|
||||||
response = requests.get(url=f"http://{host}:{port}/health")
|
response = httpx.get(url=f"http://{host}:{port}/health")
|
||||||
print(json.dumps(response.json(), indent=4)) # noqa
|
print(json.dumps(response.json(), indent=4)) # noqa
|
||||||
return
|
return
|
||||||
if test is not False:
|
if test is not False:
|
||||||
|
|
|
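For these one-off CLI calls the swap is mechanical: httpx exposes module-level get/post helpers that accept a json= keyword and return responses with .json() and .raise_for_status(), so the queue, polling, and health requests port over line for line. A small sketch of the equivalent standalone call (host and port are illustrative):

import httpx

# one-shot request; no shared client is needed for a CLI health check
resp = httpx.get("http://0.0.0.0:4000/health", timeout=10.0)
resp.raise_for_status()
print(resp.json())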
@@ -11,9 +11,6 @@ import random
 import traceback
 from typing import Optional

-import dotenv  # type: ignore
-import requests
-
 from litellm.caching.caching import DualCache
 from litellm.integrations.custom_logger import CustomLogger
@@ -6,10 +6,6 @@ import traceback
 from datetime import datetime
 from typing import Dict, List, Optional, Union

-import dotenv
-import requests
-from pydantic import BaseModel
-
 from litellm import token_counter
 from litellm._logging import verbose_router_logger
 from litellm.caching.caching import DualCache
@@ -43,7 +43,6 @@ import aiohttp
 import dotenv
 import httpx
 import openai
-import requests
 import tiktoken
 from httpx import Proxy
 from httpx._utils import get_environment_proxies
@@ -4175,7 +4174,7 @@ def get_max_tokens(model: str) -> Optional[int]:
         config_url = f"https://huggingface.co/{model_name}/raw/main/config.json"
         try:
             # Make the HTTP request to get the raw JSON file
-            response = requests.get(config_url)
+            response = litellm.module_level_client.get(config_url)
             response.raise_for_status()  # Raise an exception for bad responses (4xx or 5xx)

             # Parse the JSON response
@@ -4186,7 +4185,7 @@ def get_max_tokens(model: str) -> Optional[int]:
                 return max_position_embeddings
             else:
                 return None
-        except requests.exceptions.RequestException:
+        except Exception:
             return None

     try:
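Dropping requests also removes requests.exceptions.RequestException, so this hunk and the matching one in get_model_info below fall back to a bare except Exception. If narrower handling were wanted, httpx's exception hierarchy is the closest equivalent; a hedged alternative (not what the commit does, and fetch_config is an illustrative name) would be:

import httpx
import litellm


def fetch_config(config_url: str):
    # httpx.HTTPError covers both transport errors and raise_for_status() failures,
    # assuming module_level_client returns httpx responses
    try:
        response = litellm.module_level_client.get(config_url)
        response.raise_for_status()
        return response.json()
    except httpx.HTTPError:
        return None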
@@ -4361,7 +4360,7 @@ def get_model_info(  # noqa: PLR0915

         try:
             # Make the HTTP request to get the raw JSON file
-            response = requests.get(config_url)
+            response = litellm.module_level_client.get(config_url)
             response.raise_for_status()  # Raise an exception for bad responses (4xx or 5xx)

             # Parse the JSON response
@@ -4374,7 +4373,7 @@ def get_model_info(  # noqa: PLR0915
                 return max_position_embeddings
             else:
                 return None
-        except requests.exceptions.RequestException:
+        except Exception:
             return None

     try:
tests/documentation_tests/test_requests_lib_usage.py (new file, 183 lines)
@@ -0,0 +1,183 @@
+"""
+Prevent usage of 'requests' library in the codebase.
+"""
+
+import os
+import ast
+import sys
+from typing import List, Tuple
+
+
+def find_requests_usage(directory: str) -> List[Tuple[str, int, str]]:
+    """
+    Recursively search for Python files in the given directory
+    and find usages of the 'requests' library.
+
+    Args:
+        directory (str): The root directory to search for Python files
+
+    Returns:
+        List of tuples containing (file_path, line_number, usage_type)
+    """
+    requests_usages = []
+
+    def is_likely_requests_usage(node):
+        """
+        More precise check to avoid false positives
+        """
+        try:
+            # Convert node to string representation
+            node_str = ast.unparse(node)
+
+            # Specific checks to ensure it's the requests library
+            requests_identifiers = [
+                # HTTP methods
+                "requests.get",
+                "requests.post",
+                "requests.put",
+                "requests.delete",
+                "requests.head",
+                "requests.patch",
+                "requests.options",
+                "requests.request",
+                "requests.session",
+                # Types and exceptions
+                "requests.Response",
+                "requests.Request",
+                "requests.Session",
+                "requests.ConnectionError",
+                "requests.HTTPError",
+                "requests.Timeout",
+                "requests.TooManyRedirects",
+                "requests.RequestException",
+                # Additional modules and attributes
+                "requests.api",
+                "requests.exceptions",
+                "requests.models",
+                "requests.auth",
+                "requests.cookies",
+                "requests.structures",
+            ]
+
+            # Check for specific requests library identifiers
+            return any(identifier in node_str for identifier in requests_identifiers)
+        except:
+            return False
+
+    def scan_file(file_path: str):
+        """
+        Scan a single Python file for requests library usage
+        """
+        try:
+            # Use utf-8-sig to handle files with BOM, ignore errors
+            with open(file_path, "r", encoding="utf-8-sig", errors="ignore") as file:
+                tree = ast.parse(file.read())
+
+            for node in ast.walk(tree):
+                # Check import statements
+                if isinstance(node, ast.Import):
+                    for alias in node.names:
+                        if alias.name == "requests":
+                            requests_usages.append(
+                                (file_path, node.lineno, f"Import: {alias.name}")
+                            )
+
+                # Check import from statements
+                elif isinstance(node, ast.ImportFrom):
+                    if node.module == "requests":
+                        requests_usages.append(
+                            (file_path, node.lineno, f"Import from: {node.module}")
+                        )
+
+                # Check method calls
+                elif isinstance(node, ast.Call):
+                    # More precise check for requests usage
+                    try:
+                        if is_likely_requests_usage(node.func):
+                            requests_usages.append(
+                                (
+                                    file_path,
+                                    node.lineno,
+                                    f"Method Call: {ast.unparse(node.func)}",
+                                )
+                            )
+                    except:
+                        pass
+
+                # Check attribute access
+                elif isinstance(node, ast.Attribute):
+                    try:
+                        # More precise check
+                        if is_likely_requests_usage(node):
+                            requests_usages.append(
+                                (
+                                    file_path,
+                                    node.lineno,
+                                    f"Attribute Access: {ast.unparse(node)}",
+                                )
+                            )
+                    except:
+                        pass
+
+        except SyntaxError as e:
+            print(f"Syntax error in {file_path}: {e}", file=sys.stderr)
+        except Exception as e:
+            print(f"Error processing {file_path}: {e}", file=sys.stderr)
+
+    # Recursively walk through directory
+    for root, dirs, files in os.walk(directory):
+        # Remove virtual environment and cache directories from search
+        dirs[:] = [
+            d
+            for d in dirs
+            if not any(
+                venv in d
+                for venv in [
+                    "venv",
+                    "env",
+                    "myenv",
+                    ".venv",
+                    "__pycache__",
+                    ".pytest_cache",
+                ]
+            )
+        ]
+
+        for file in files:
+            if file.endswith(".py"):
+                full_path = os.path.join(root, file)
+                # Skip files in virtual environment or cache directories
+                if not any(
+                    venv in full_path
+                    for venv in [
+                        "venv",
+                        "env",
+                        "myenv",
+                        ".venv",
+                        "__pycache__",
+                        ".pytest_cache",
+                    ]
+                ):
+                    scan_file(full_path)
+
+    return requests_usages
+
+
+def main():
+    # Get directory from command line argument or use current directory
+    directory = "../../litellm"
+
+    # Find requests library usages
+    results = find_requests_usage(directory)
+
+    # Print results
+    if results:
+        print("Requests Library Usages Found:")
+        for file_path, line_num, usage_type in results:
+            print(f"{file_path}:{line_num} - {usage_type}")
+    else:
+        print("No requests library usages found.")
+
+
+if __name__ == "__main__":
+    main()
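The new guard script walks a directory with ast and reports any import of requests, any from-import, and any call or attribute access that looks like the requests API. Since main() hardcodes the relative path "../../litellm", it is presumably meant to be run from tests/documentation_tests/; invoking the scanner directly looks roughly like this (the working-directory assumption is mine):

# run from tests/documentation_tests/ so the relative path in main() resolves
from test_requests_lib_usage import find_requests_usage

violations = find_requests_usage("../../litellm")
for file_path, line_no, usage in violations:
    print(f"{file_path}:{line_no} - {usage}")

# an empty result means no direct requests usage survives in the package
assert not violations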
@@ -1940,10 +1940,11 @@ def test_ollama_image():
         mock_response = MagicMock()
         mock_response.status_code = 200
         mock_response.headers = {"Content-Type": "application/json"}
+        data_json = json.loads(kwargs["data"])
         mock_response.json.return_value = {
             # return the image in the response so that it can be tested
             # against the original
-            "response": kwargs["json"]["images"]
+            "response": data_json["images"]
         }
         return mock_response
@@ -1971,9 +1972,10 @@ def test_ollama_image():
         [datauri_base64_data, datauri_base64_data],
     ]

+    client = HTTPHandler()
     for test in tests:
         try:
-            with patch("requests.post", side_effect=mock_post):
+            with patch.object(client, "post", side_effect=mock_post):
                 response = completion(
                     model="ollama/llava",
                     messages=[
@@ -1988,6 +1990,7 @@ def test_ollama_image():
                             ],
                         }
                     ],
+                    client=client,
                 )
             if not test[1]:
                 # the conversion process may not always generate the same image,
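Because the ollama path no longer goes through requests.post, the test can no longer patch requests globally; it now builds an explicit HTTPHandler, patches that handler's post method, and injects it via client=. The same arrangement reduced to a skeleton (the mocked payload is illustrative and kept to the minimal shape the test above relies on):

from unittest.mock import MagicMock, patch

from litellm import completion
from litellm.llms.custom_httpx.http_handler import HTTPHandler

client = HTTPHandler()
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.headers = {"Content-Type": "application/json"}
mock_response.json.return_value = {"response": "ok"}  # illustrative payload

# patch the injected client's post method instead of the global requests.post
with patch.object(client, "post", return_value=mock_response):
    completion(
        model="ollama/llava",
        messages=[{"role": "user", "content": "hi"}],
        client=client,
    )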
@@ -2387,8 +2390,8 @@ def test_completion_ollama_hosted():
         response = completion(
             model="ollama/phi",
             messages=messages,
-            max_tokens=2,
-            api_base="https://test-ollama-endpoint.onrender.com",
+            max_tokens=20,
+            # api_base="https://test-ollama-endpoint.onrender.com",
         )
         # Add any assertions here to check the response
         print(response)
@@ -606,14 +606,14 @@ def test_completion_azure_function_calling_stream():
 @pytest.mark.skip("Flaky ollama test - needs to be fixed")
 def test_completion_ollama_hosted_stream():
     try:
-        litellm.set_verbose = True
+        # litellm.set_verbose = True
         response = completion(
             model="ollama/phi",
             messages=messages,
-            max_tokens=10,
+            max_tokens=100,
             num_retries=3,
             timeout=20,
-            api_base="https://test-ollama-endpoint.onrender.com",
+            # api_base="https://test-ollama-endpoint.onrender.com",
             stream=True,
         )
         # Add any assertions here to check the response