fix(ollama.py): fix type issue

This commit is contained in:
Krrish Dholakia 2024-03-28 15:01:56 -07:00
parent 90de9d3c10
commit 48af367885

View file

@ -68,9 +68,9 @@ class OllamaConfig:
repeat_last_n: Optional[int] = None repeat_last_n: Optional[int] = None
repeat_penalty: Optional[float] = None repeat_penalty: Optional[float] = None
temperature: Optional[float] = None temperature: Optional[float] = None
stop: Optional[ stop: Optional[list] = (
list None # stop is a list based on this - https://github.com/jmorganca/ollama/pull/442
] = None # stop is a list based on this - https://github.com/jmorganca/ollama/pull/442 )
tfs_z: Optional[float] = None tfs_z: Optional[float] = None
num_predict: Optional[int] = None num_predict: Optional[int] = None
top_k: Optional[int] = None top_k: Optional[int] = None
@ -346,7 +346,7 @@ async def ollama_acompletion(url, data, model_response, encoding, logging_obj):
async def ollama_aembeddings( async def ollama_aembeddings(
api_base: str, api_base: str,
model: str, model: str,
prompts: list[str], prompts: list,
optional_params=None, optional_params=None,
logging_obj=None, logging_obj=None,
model_response=None, model_response=None,
@ -378,9 +378,13 @@ async def ollama_aembeddings(
logging_obj.pre_call( logging_obj.pre_call(
input=None, input=None,
api_key=None, api_key=None,
additional_args={"api_base": url, "complete_input_dict": data, "headers": {}}, additional_args={
"api_base": url,
"complete_input_dict": data,
"headers": {},
},
) )
response = await session.post(url, json=data) response = await session.post(url, json=data)
if response.status != 200: if response.status != 200:
text = await response.text() text = await response.text()