271 líneas
8.1 KiB
Python
271 líneas
8.1 KiB
Python
import os
import pdb
import sys

import requests
from requests.models import urlencode
|
|
|
|
###
# Dict helper class.
# Defined at top level so it can be pickled.
###
|
|
class AttribAccessDict(dict):
    """A ``dict`` whose existing keys can also be read as attributes.

    ``d.key`` is equivalent to ``d["key"]`` for keys that are present; a
    missing key raises ``AttributeError``. Attribute assignment to a name
    that is already a key is refused, while assignment to any other name
    creates an ordinary instance attribute (it does not add a dict key).
    """

    def __getattr__(self, attr):
        # Python only calls this after normal attribute lookup has failed,
        # so real attributes and methods always take precedence over keys.
        if attr not in self:
            raise AttributeError("Attribute not found: " + str(attr))
        return self[attr]

    def __setattr__(self, attr, val):
        # Writing through attribute syntax must not shadow or alter a key.
        key_exists = attr in self
        if key_exists:
            raise AttributeError("Attribute-style access is read only")
        super().__setattr__(attr, val)
|
|
|
|
class Xat:
    """Small client for a few OpenAI REST endpoints.

    The API base URL and bearer token come from the constructor arguments
    when given; otherwise they are read from ``config/xatbot.txt`` (a file of
    ``key: value`` lines). If that file is missing, the user is prompted on
    stdin and the answers are appended to it.
    """

    name = 'OpenAI API Python wrapper'

    def __init__(self, api_base_url=None, access_token=None, session=None, headers=None):
        """Configure credentials, default headers and the HTTP session.

        Args:
            api_base_url: API root such as ``'https://api.openai.com'``. If None,
                it is taken from the config file (or prompted for).
            access_token: Bearer token. If None, it is taken from the config
                file (or prompted for).
            session: ``requests.Session`` to reuse; a new one is created when
                falsy.
            headers: Extra headers sent with every request. They are merged
                over the ``Authorization`` header, so they win on conflicts.
        """
        self.__xatbot_config_path = "config/xatbot.txt"

        # Explicit arguments take precedence; the config file (or the
        # interactive setup) is only consulted for values not passed in.
        if api_base_url is None or access_token is None:
            if self.__check_setup():
                config_url = self.__get_parameter("api_base_url", self.__xatbot_config_path)
                config_token = self.__get_parameter("access_token", self.__xatbot_config_path)
            else:
                config_url, config_token = self.__setup()
            if api_base_url is None:
                api_base_url = config_url
            if access_token is None:
                access_token = config_token

        self.api_base_url = api_base_url
        self.access_token = access_token

        self.headers = {"Authorization": "Bearer " + self.access_token} if self.access_token else {}
        if headers:
            self.headers.update(headers)

        self.session = session if session else requests.Session()

    def completions(self, model=None, prompt=None, max_tokens=750, temperature=0):
        """POST a text-completion request.

        Args:
            model: Model id; ``"text-davinci-003"`` is used when not given.
            prompt: Prompt text.
            max_tokens: Upper bound on generated tokens (default 750).
            temperature: Sampling temperature (default 0).

        Returns:
            The decoded JSON body; a top-level object is an AttribAccessDict.
        """
        data = {
            'model': model or "text-davinci-003",
            'prompt': prompt,
            'max_tokens': max_tokens,
            'temperature': temperature,
        }
        return self.__request_json('POST', '/v1/completions', data)

    def models(self):
        """GET the list of available models and return the decoded JSON body."""
        return self.__request_json('GET', '/v1/models', {})

    def get_model(self, model=None):
        """GET the description of *model* and return the decoded JSON body.

        Note: with ``model=None`` the request goes to ``/v1/models/None``.
        """
        return self.__request_json('GET', f'/v1/models/{model}', {})

    def create_image(self, model=None, prompt=None, n=1, size="512x512"):
        """POST an image-generation request; images are returned base64-encoded.

        Args:
            model: Optional model id. It is only sent when given, so the
                server-side default applies otherwise.
            prompt: Text description of the image.
            n: Number of images to request (default 1).
            size: Image size string (default ``"512x512"``).

        Returns:
            The decoded JSON body; a top-level object is an AttribAccessDict.
        """
        data = {
            'prompt': prompt,
            'n': n,
            'size': size,
            'response_format': "b64_json",
        }
        if model:
            data['model'] = model
        return self.__request_json('POST', '/v1/images/generations', data)

    def __request_json(self, method, path, data):
        """Send a request to ``api_base_url + path`` and decode its JSON body."""
        response = self.__api_request(method, self.api_base_url + path, data)
        return self.__json_allow_dict_attrs(response.json())

    def __check_setup(self):
        """Return True if the config file exists.

        When it does not, a notice is printed and False is returned so the
        caller can run the interactive setup.
        """
        if os.path.isfile(self.__xatbot_config_path):
            return True
        print(f"File {self.__xatbot_config_path} not found, running setup.")
        return False

    def __setup(self):
        """Prompt for the API URL and token and append them to the config file.

        Returns:
            tuple: ``(api_base_url, access_token)`` as entered, whitespace-stripped.
        """
        # Derive the directory from the config path instead of hard-coding
        # 'config', so the two cannot drift apart.
        config_dir = os.path.dirname(self.__xatbot_config_path)
        if config_dir:
            os.makedirs(config_dir, exist_ok=True)

        # Strip so first-run values match what is read back on later runs.
        api_base_url = input("Openai API url, in ex. 'https://api.openai.com': ").strip()
        access_token = input("Openai access token: ").strip()

        if not os.path.exists(self.__xatbot_config_path):
            print(f"{self.__xatbot_config_path} created!")

        # Append mode creates the file when missing and never clobbers content.
        with open(self.__xatbot_config_path, 'a', encoding='utf-8') as config_file:
            print("Writing xatbot parameters to " + self.__xatbot_config_path)
            config_file.write(f'api_base_url: {api_base_url}\n' f'access_token: {access_token}\n')

        return api_base_url, access_token

    @staticmethod
    def __get_parameter(parameter, file_path):
        """Return the value stored for *parameter* in a ``key: value`` file.

        The key must match exactly (surrounding whitespace ignored); only the
        first ``:`` separates key from value, so values such as URLs may
        contain colons. If the key is absent, a message is printed and the
        process exits with status 1.
        """
        with open(file_path, encoding='utf-8') as config_file:
            for line in config_file:
                key, sep, value = line.partition(":")
                if sep and key.strip() == parameter:
                    return value.strip()

        print(f'{file_path} Missing parameter {parameter}')
        # Non-zero status: a missing setting is a configuration error.
        sys.exit(1)

    def __api_request(self, method, endpoint, data):
        """Send *data* as a JSON body to *endpoint* and return the response.

        Successful responses, and error responses with status 400, 422 or 503,
        are returned unchanged so the caller can read their body. Any other
        error status is raised as an exception whose args are
        ``('Openai API returned error', status_code, reason, error_msg)``.

        Raises:
            OpenaiNetworkError: the HTTP call itself raised.
            OpenaiNotFoundError: status 404.
            OpenaiUnauthorizedError: status 401.
            OpenaiInternalServerError, OpenaiBadGatewayError,
            OpenaiGatewayTimeoutError: status 500, 502, 504 respectively.
            OpenaiServerError: any other status in 500-511 (503 excepted).
            OpenaiAPIError: every remaining error status.
        """
        try:
            response = self.session.request(method, url=endpoint, headers=self.headers, json=data)
        except Exception as e:
            raise OpenaiNetworkError(f"Could not complete request: {e}") from e

        if response is None:
            # Only reachable with a custom session object. The name previously
            # raised here (OpenaiIllegalArgumentError) does not appear to be
            # defined in this file, so it would have been a NameError.
            raise OpenaiAPIError("Illegal request.")

        if response.ok:
            return response

        status = response.status_code
        # These statuses are deliberately handed back instead of raising.
        if status in (400, 422, 503):
            return response

        error_msg = self.__error_message(response)

        if status == 404:
            ex_type = OpenaiNotFoundError
            if not error_msg:
                # For compatibility with older versions, which raised
                # OpenaiAPIError('Endpoint not found.') on any 404.
                error_msg = 'Endpoint not found.'
        elif status == 401:
            ex_type = OpenaiUnauthorizedError
        elif status == 500:
            ex_type = OpenaiInternalServerError
        elif status == 502:
            ex_type = OpenaiBadGatewayError
        elif status == 504:
            ex_type = OpenaiGatewayTimeoutError
        elif 500 <= status <= 511:
            ex_type = OpenaiServerError
        else:
            ex_type = OpenaiAPIError

        raise ex_type('Openai API returned error', status, response.reason, error_msg)

    @staticmethod
    def __error_message(response):
        """Best-effort extraction of an error message from a response body.

        Returns the ``"error"`` member of a JSON object body, the body itself
        if it is a JSON string, or None (including when the body is not JSON).
        """
        try:
            body = response.json()
        except ValueError:
            return None
        if isinstance(body, dict) and 'error' in body:
            return body['error']
        if isinstance(body, str):
            return body
        return None

    @staticmethod
    def __json_allow_dict_attrs(json_object):
        """
        Makes it possible to use attribute notation to access a dicts
        elements, while still allowing the dict to act as a dict.

        Only a top-level dict is wrapped; nested values are returned as-is.
        """
        if isinstance(json_object, dict):
            return AttribAccessDict(json_object)
        return json_object
|
|
|
|
##
# Exceptions
##
|
|
class OpenaiError(Exception):
    """Common base class of every exception type defined by this wrapper."""
|
|
|
|
class OpenaiIOError(IOError, OpenaiError):
    """Base class for the wrapper's I/O failures.

    It also derives from ``IOError`` (an alias of ``OSError``), so generic
    OS-error handlers catch it as well.
    """
|
|
|
|
class OpenaiNetworkError(OpenaiIOError):
    """Raised when the HTTP request to the server could not be completed."""
|
|
class OpenaiAPIError(OpenaiError):
    """Raised when the OpenAI API returns a response that cannot be handled.

    It is the base class of the more specific HTTP-status exceptions. It is
    also raised directly for error statuses that have no dedicated subclass.
    """
|
|
class OpenaiServerError(OpenaiAPIError):
    """Base class for 5xx server errors.

    It is also raised directly for 5xx statuses that have no more specific
    subclass (for example, a misconfigured server).
    """
|
|
class OpenaiInternalServerError(OpenaiServerError):
    """Raised when the server answers with HTTP 500 Internal Server Error."""
|
|
|
|
class OpenaiBadGatewayError(OpenaiServerError):
    """Raised when the server answers with HTTP 502 Bad Gateway."""
|
|
|
|
class OpenaiServiceUnavailableError(OpenaiServerError):
    """Represents an HTTP 503 Service Unavailable response.

    NOTE: ``Xat`` currently returns 503 responses to the caller instead of
    raising this exception.
    """
|
|
class OpenaiGatewayTimeoutError(OpenaiServerError):
    """Raised when the server answers with HTTP 504 Gateway Timeout."""
|
|
class OpenaiNotFoundError(OpenaiAPIError):
    """Raised when the OpenAI API returns a 404 Not Found error."""
|
|
|
|
class OpenaiUnauthorizedError(OpenaiAPIError):
    """Raised when the OpenAI API returns a 401 Unauthorized error.

    This typically means the bearer access token was missing, invalid or
    revoked, or that the endpoint requires credentials that were not sent.
    """
|