import datetime
import logging
import os
import threading
import time
from typing import Iterator, Optional, Union
import openai
from banterbot.config import RETRY_LIMIT, RETRY_TIME
from banterbot.data.enums import EnvVar
from banterbot.handlers.stream_handler import StreamHandler
from banterbot.managers.stream_manager import StreamManager
from banterbot.models.message import Message
from banterbot.models.openai_model import OpenAIModel
from banterbot.models.stream_log_entry import StreamLogEntry
from banterbot.utils.nlp import NLP
class OpenAIService:
"""
A class that handles the interaction with the OpenAI ChatCompletion API. It provides functionality to generate
responses from the API based on the input messages. It supports generating responses as a whole or as a stream of
response blocks.
    The main purpose of this class is to facilitate communication with the OpenAI API and to handle the
    responses it generates. It can be used to build chatbots or other applications that require natural
    language processing and generation.
"""
api_key_set = False
client = None
def __init__(self, model: OpenAIModel) -> None:
"""
Initializes an `OpenAIService` instance for a specific model.
Args:
model (OpenAIModel): The OpenAI model to be used. This should be an instance of the OpenAIModel class, which
contains information about the model, such as its name and maximum token limit.
"""
logging.debug(f"OpenAIService initialized")
# Set the OpenAI API key
if not self.__class__.api_key_set:
api_key = os.environ.get(EnvVar.OPENAI_API_KEY.value)
self.__class__.client = openai.OpenAI(api_key=api_key)
self.__class__.api_key_set = True
# The selected model that will be used in OpenAI ChatCompletion prompts.
self._model = model
# Indicates whether the current instance of `OpenAIService` is streaming.
self._streaming = False
# Set the interruption flag to zero: if interruptions are raised, this will be updated.
self._interrupt = 0
# Initialize the StreamManager for handling streaming processes.
self._stream_manager = StreamManager()
self._stream_manager.connect_processor(self._processor)
self._stream_manager.connect_completion_handler(self._completion_handler)
# A list of active stream handlers.
self._stream_handlers = []
self._stream_handlers_lock = threading.Lock()
def interrupt(self, kill: bool = False) -> None:
"""
Interrupts the current OpenAI ChatCompletion process.
Args:
            kill (bool): Whether the interruption should also kill the underlying queues.
"""
self._interrupt = time.perf_counter_ns()
with self._stream_handlers_lock:
for handler in self._stream_handlers:
handler.interrupt(kill=kill)
self._stream_handlers.clear()
logging.debug(f"OpenAIService Interrupted")
def count_tokens(self, string: str) -> int:
"""
Counts the number of tokens in the provided string.
Args:
            string (str): The string whose tokens are to be counted.
Returns:
int: The number of tokens in the string.
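        Example:
            Illustrative only; the exact count depends on the model's tokenizer::

                service.count_tokens("Hello, world!")  # e.g. 4 tokens for a typical GPT BPE tokenizer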
"""
return len(self._model.tokenizer.encode(string))
    def prompt(self, messages: list[Message], split: bool = True, **kwargs) -> Union[list[str], str]:
"""
Sends messages to the OpenAI ChatCompletion API and retrieves the response as a list of sentences.
Args:
messages (list[Message]): A list of messages. Each message should be an instance of the `Message` class,
which contains the content and role (user or assistant) of the message.
split (bool): Whether the response should be split into sentences.
**kwargs: Additional parameters for the API request. These can include settings such as temperature, top_p,
and frequency_penalty.
Returns:
Union[list[str], str]: A list of sentences forming the response from the OpenAI API. This can be used to
display the generated response to the user or for further processing. If `split` is False, returns a string.
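        Example:
            Illustrative sketch only; it assumes a valid `OPENAI_API_KEY` in the environment and that
            `Message` accepts `role` and `content` arguments::

                service = OpenAIService(model=model)
                for sentence in service.prompt(messages=[Message(role="user", content="Tell me a joke.")]):
                    print(sentence)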
"""
response = self._request(messages=messages, stream=False, **kwargs)
logging.debug(f"OpenAIService stream processed block: `{response}`")
sentences = NLP.segment_sentences(response) if split else response
return sentences
def prompt_stream(
self, messages: list[Message], init_time: Optional[int] = None, **kwargs
) -> Union[StreamHandler, tuple[()]]:
"""
Sends messages to the OpenAI API and retrieves the response as a stream of blocks of sentences.
Args:
messages (list[Message]): A list of messages. Each message should be an instance of the `Message` class,
which contains the content and role (user or assistant) of the message.
            init_time (Optional[int]): The time at which the stream was initialized, in nanoseconds (as
                returned by `time.perf_counter_ns`).
**kwargs: Additional parameters for the API request. These can include settings such as temperature, top_p,
and frequency_penalty.
Returns:
Union[StreamHandler, tuple[()]]: A handler for the stream of blocks of sentences forming the response from
the OpenAI API or an empty tuple if the stream was interrupted.
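        Example:
            Illustrative sketch only; it assumes a valid `OPENAI_API_KEY` in the environment and that the
            returned `StreamHandler` is iterable, yielding blocks of sentences as they arrive::

                handler = service.prompt_stream(messages=[Message(role="user", content="Tell me a story.")])
                if handler:
                    for block in handler:
                        print(block)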
"""
# Record the time at which the stream was initialized pre-lock, in order to account for future interruptions.
init_time = time.perf_counter_ns() if init_time is None else init_time
if self._interrupt >= init_time:
return tuple()
else:
# Obtain a response from the OpenAI ChatCompletion API
stream = self._request(messages=messages, stream=True, **kwargs)
handler = self._stream_manager.stream(
iterable=stream,
# close_stream=stream.response.close,
init_shared_data={"text": "", "sentences": [], "init_time": init_time},
)
with self._stream_handlers_lock:
self._stream_handlers.append(handler)
return handler
@property
def model(self) -> OpenAIModel:
"""
Return the `OpenAIModel` associated with the current instance.
Returns:
OpenAIModel
"""
return self._model
    def _processor(self, log: list[StreamLogEntry], index: int, shared_data: dict) -> Optional[list[str]]:
"""
Parses a chunk of data from the OpenAI API response.
Args:
log (list[StreamLogEntry]): A list of `StreamLogEntry` instances containing the data from the OpenAI API
response.
index (int): The index of the current chunk of data.
shared_data (dict): A dictionary containing shared data between the stream handler and the processor.
Returns:
            Optional[list[str]]: A list of sentences parsed from the chunk, or None if the chunk did not
                complete a new sentence.
"""
if shared_data["interrupt"] >= shared_data["init_time"]:
raise StopIteration
else:
if log[index].value.choices[0].delta.content is not None:
shared_data["text"] += log[index].value.choices[0].delta.content
shared_data["sentences"] = NLP.segment_sentences(shared_data["text"])
            # If at least one complete sentence is available, retain the trailing (possibly incomplete)
            # text for the next chunk and yield the completed sentences.
if len(shared_data["sentences"]) > 1:
shared_data["text"] = shared_data["sentences"][-1]
logging.debug(f"OpenAIService yielded sentences: {shared_data['sentences'][:-1]}")
return shared_data["sentences"][:-1]
def _completion_handler(self, log: list[StreamLogEntry], shared_data: dict) -> list[str]:
"""
Handles the completion of the OpenAI API response.
Args:
log (list[StreamLogEntry]): A list of `StreamLogEntry` instances containing the data from the OpenAI API
response.
shared_data (dict): A dictionary containing shared data between the stream handler and the processor.
Returns:
            list[str]: The remaining sentences parsed from the accumulated text.
"""
if shared_data["interrupt"] >= shared_data["init_time"]:
raise StopIteration
else:
# If the current chunk is the final chunk of data from the OpenAI API response, parse the final chunk.
shared_data["sentences"] = NLP.segment_sentences(shared_data["text"])
logging.debug(f"OpenAIService yielded final sentences: {shared_data['sentences'][:-1]}")
logging.debug("OpenAIService stream stopped")
return shared_data["sentences"]
def _request(self, messages: list[Message], stream: bool, **kwargs) -> Union[Iterator, str]:
"""
Sends a request to the OpenAI API and generates a response based on the specified parameters.
Args:
messages (list[Message]): A list of messages. Each message should be an instance of the `Message` class,
which contains the content and role (user or assistant) of the message.
stream (bool): Whether the response should be returned as an iterable stream or a complete text.
**kwargs: Additional parameters for the API request. These can include settings such as temperature, top_p,
and frequency_penalty.
Returns:
            Union[Iterator, str]: The response from the OpenAI API, either as an iterator over response
                chunks (if `stream` is True) or as the complete response text (if `stream` is False).
"""
kwargs["model"] = self._model.model
kwargs["n"] = 1
kwargs["stream"] = stream
kwargs["messages"] = [message() for message in messages]
success = False
for i in range(RETRY_LIMIT):
try:
response = self.__class__.client.chat.completions.create(**kwargs)
success = True
break
except openai.RateLimitError:
retry_timestamp = datetime.datetime.now() + datetime.timedelta(seconds=RETRY_TIME)
                retry_timestamp = retry_timestamp.strftime("%H:%M:%S")
error_message = (
f"OpenAIService encountered an OpenAI Rate Limiting Error - Attempt {i+1}/{RETRY_LIMIT}."
f" Waiting {RETRY_TIME} seconds until {retry_timestamp} to retry."
)
logging.info(error_message)
time.sleep(RETRY_TIME)
except openai.APIError:
retry_time = 0.25
retry_timestamp = datetime.datetime.now() + datetime.timedelta(seconds=retry_time)
retry_timestamp = retry_timestamp.strftime("%H:%M:%S")
error_message = (
f"OpenAIService encountered an OpenAI API Error - Attempt {i+1}/{RETRY_LIMIT}. Waiting "
f"{retry_time} seconds until {retry_timestamp} to retry."
)
logging.info(error_message)
time.sleep(retry_time)
if not success:
raise RuntimeError(f"OpenAIService encountered too many OpenAI API Errors; exiting program.")
return response if stream else response.choices[0].message.content.strip()
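

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only; not part of the service). It assumes that `OPENAI_API_KEY`
    # is set in the environment and that `Message(role=..., content=...)` matches the actual `Message`
    # constructor. `my_model` is a hypothetical placeholder: construct an `OpenAIModel` however the
    # surrounding application normally does.
    logging.basicConfig(level=logging.DEBUG)
    my_model: OpenAIModel = ...  # Replace with a real OpenAIModel instance.
    service = OpenAIService(model=my_model)
    response = service.prompt(messages=[Message(role="user", content="Say hello in one sentence.")], split=False)
    print(response)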