diff --git a/src/IRCBot.py b/src/IRCBot.py index df0b0540..f9493367 100644 --- a/src/IRCBot.py +++ b/src/IRCBot.py @@ -36,7 +36,7 @@ class ListLambdaPollHook(PollHook.PollHook): class Bot(object): def __init__(self, directory, args, cache, config, database, events, - exports, log, modules, timers): + exports, log, modules, timers, lock_file): self.directory = directory self.args = args self.cache = cache @@ -71,6 +71,7 @@ class Bot(object): self._poll_timeouts = [] # typing.List[PollHook] self._poll_timeouts.append(self._timers) self._poll_timeouts.append(self.cache) + self._poll_timeouts.append(lock_file) self._poll_timeouts.append(ListLambdaPollHook( lambda: self.servers.values(), diff --git a/src/LockFile.py b/src/LockFile.py new file mode 100644 index 00000000..9f53b8d5 --- /dev/null +++ b/src/LockFile.py @@ -0,0 +1,40 @@ +import datetime, os +from src import PollHook, utils + +EXPIRATION = 60 # 1 minute + +class LockFile(PollHook.PollHook): + def __init__(self, database_location: str): + self._database_location = database_location + self._lock_location = "%s.lock" % database_location + self._next_lock = None + + def available(self): + now = utils.datetime_utcnow() + if os.path.exists(self._lock_location): + with open(self._lock_location, "r") as lock_file: + timestamp_str = lock_file.read().strip().split(" ", 1)[0] + + timestamp = utils.iso8601_parse(timestamp_str) + + if (now-timestamp).total_seconds() < EXPIRATION: + return False + + return True + + def lock(self): + with open(self._lock_location, "w") as lock_file: + last_lock = utils.datetime_utcnow() + lock_file.write("%s" % utils.iso8601_format(last_lock)) + self._next_lock = last_lock+datetime.timedelta( + seconds=EXPIRATION/2) + + def next(self): + return max(0, (self._next_lock-utils.datetime_utcnow()).total_seconds()) + def call(self): + if self.next() == 0: + self.lock() + + def unlock(self): + if os.path.isfile(self._lock_location): + os.remove(self._lock_location) diff --git a/src/utils/__init__.py b/src/utils/__init__.py index eaf3cc03..d36112a5 100644 --- a/src/utils/__init__.py +++ b/src/utils/__init__.py @@ -15,6 +15,9 @@ ISO8601_FORMAT_TZ = "%z" DATETIME_HUMAN = "%Y/%m/%d %H:%M:%S" +def datetime_utcnow() -> datetime.datetime: + return datetime.datetime.utcnow().replace(tzinfo=datetime.timezone.utc) + def iso8601_format(dt: datetime.datetime, milliseconds: bool=False) -> str: dt_format = dt.strftime(ISO8601_FORMAT_DT) tz_format = dt.strftime(ISO8601_FORMAT_TZ) @@ -25,8 +28,7 @@ def iso8601_format(dt: datetime.datetime, milliseconds: bool=False) -> str: return "%s%s%s" % (dt_format, ms_format, tz_format) def iso8601_format_now(milliseconds: bool=False) -> str: - now = datetime.datetime.utcnow().replace(tzinfo=datetime.timezone.utc) - return iso8601_format(now, milliseconds=milliseconds) + return iso8601_format(datetime_utcnow(), milliseconds=milliseconds) def iso8601_parse(s: str, microseconds: bool=False) -> datetime.datetime: fmt = ISO8601_PARSE_MICROSECONDS if microseconds else ISO8601_PARSE return datetime.datetime.strptime(s, fmt) diff --git a/start.py b/start.py index 47b703ec..9c000521 100755 --- a/start.py +++ b/start.py @@ -6,9 +6,9 @@ if sys.version_info < (3, 6): sys.stderr.write("BitBot requires python 3.6.0 or later\n") sys.exit(1) -import argparse, faulthandler, os, platform, time +import atexit, argparse, faulthandler, os, platform, time from src import Cache, Config, Database, EventManager, Exports, IRCBot -from src import Logging, ModuleManager, Timers, utils +from src import LockFile, Logging, ModuleManager, Timers, utils faulthandler.enable() @@ -65,6 +65,14 @@ log = Logging.Log(not args.no_logging, log_level, args.log_dir) log.info("Starting BitBot %s (Python v%s)", [IRCBot.VERSION, platform.python_version()]) +lock_file = LockFile.LockFile(args.database) +if not lock_file.available(): + log.critical("Database is locked. Is BitBot already running?") + sys.exit(1) + +atexit.register(lock_file.unlock) +lock_file.lock() + database = Database.Database(log, args.database) if args.remove_server: @@ -98,7 +106,7 @@ modules = ModuleManager.ModuleManager(events, exports, timers, config, log, module_directories) bot = IRCBot.Bot(directory, args, cache, config, database, events, - exports, log, modules, timers) + exports, log, modules, timers, lock_file) if args.module: definition = modules.find_module(args.module)