Remove cyclical references to IRCBot

This commit is contained in:
jesopo 2018-09-28 16:51:36 +01:00
parent 02a2b41246
commit a8bf3c9300
10 changed files with 145 additions and 142 deletions

View file

@ -33,8 +33,8 @@ class Module(object):
until_next_hour = 60-now.second until_next_hour = 60-now.second
until_next_hour += ((60-(now.minute+1))*60) until_next_hour += ((60-(now.minute+1))*60)
bot.add_timer("coin-interest", INTEREST_INTERVAL, persist=False, bot.timers.add("coin-interest", INTEREST_INTERVAL,
next_due=time.time()+until_next_hour) time.time()+until_next_hour)
@Utils.hook("received.command.coins") @Utils.hook("received.command.coins")
def coins(self, event): def coins(self, event):

View file

@ -11,8 +11,8 @@ class Module(object):
until_next_hour = 60-now.second until_next_hour = 60-now.second
until_next_hour += ((60-(now.minute+1))*60) until_next_hour += ((60-(now.minute+1))*60)
bot.add_timer("database-backup", BACKUP_INTERVAL, persist=False, bot.timers.add("database-backup", BACKUP_INTERVAL,
next_due=time.time()+until_next_hour) time.time()+until_next_hour)
@Utils.hook("timer.database-backup") @Utils.hook("timer.database-backup")
def backup(self, event): def backup(self, event):

View file

@ -16,10 +16,9 @@ class Module(ModuleManager.BaseModule):
if seconds <= SECONDS_MAX: if seconds <= SECONDS_MAX:
due_time = int(time.time())+seconds due_time = int(time.time())+seconds
self.bot.add_timer("in", seconds, self.bot.timers.add_persistent("in", seconds, due_time=due_time,
target=event["target"].name, due_time=due_time, target=event["target"].name, server_id=event["server"].id,
server_id=event["server"].id, nickname=event["user"].nickname, nickname=event["user"].nickname, message=message)
message=message)
event["stdout"].write("Saved") event["stdout"].write("Saved")
else: else:
event["stderr"].write( event["stderr"].write(

View file

@ -3,10 +3,20 @@ import configparser, os
class Config(object): class Config(object):
def __init__(self, location): def __init__(self, location):
self.location = location self.location = location
self._config = {}
self.load()
def load_config(self): def load(self):
if os.path.isfile(self.location): if os.path.isfile(self.location):
with open(self.location) as config_file: with open(self.location) as config_file:
parser = configparser.ConfigParser() parser = configparser.ConfigParser()
parser.read_string(config_file.read()) parser.read_string(config_file.read())
return dict(parser["bot"].items()) self._config = dict(parser["bot"].items())
def __getitem__(self, key):
return self._config[key]
def get(self, key, default=None):
return self._config.get(key, default)
def __contains__(self, key):
return key in self.config

View file

@ -1,24 +1,27 @@
import os, select, sys, threading, time, traceback, uuid import os, select, sys, threading, time, traceback, uuid
from . import EventManager, Exports, IRCLineHandler, IRCServer, Logging from . import EventManager, Exports, IRCLineHandler, IRCServer, Logging
from . import ModuleManager, Timer from . import ModuleManager
class Bot(object): class Bot(object):
def __init__(self): def __init__(self, args, config, database, events, exports, line_handler,
log, modules, timers):
self.args = args
self.config = config
self.database = database
self._events = events
self._exports = exports
self.line_handler = line_handler
self.log = log
self.modules = modules
self.timers = timers
events.on("timer.reconnect").hook(self.reconnect)
self.start_time = time.time() self.start_time = time.time()
self.lock = threading.Lock() self.lock = threading.Lock()
self.args = None
self.database = None
self.config = None
self.servers = {} self.servers = {}
self.running = True self.running = True
self.poll = select.epoll() self.poll = select.epoll()
self.timers = []
self._events = None
self._exports = None
self.modules = None
self.log = None
self.line_handler = None
def add_server(self, server_id, connect=True): def add_server(self, server_id, connect=True):
(_, alias, hostname, port, password, ipv4, tls, nickname, (_, alias, hostname, port, password, ipv4, tls, nickname,
@ -43,46 +46,7 @@ class Bot(object):
self.servers[server.fileno()] = server self.servers[server.fileno()] = server
self.poll.register(server.fileno(), select.EPOLLOUT) self.poll.register(server.fileno(), select.EPOLLOUT)
return True return True
def setup_timers(self, event):
for setting, value in self.find_settings("timer-%"):
id = setting.split("timer-", 1)[1]
self.add_timer(value["event-name"], value["delay"], value[
"next-due"], id, **value["kwargs"])
def timer_setting(self, timer):
self.set_setting("timer-%s" % timer.id, {
"event-name": timer.event_name, "delay": timer.delay,
"next-due": timer.next_due, "kwargs": timer.kwargs})
def timer_setting_remove(self, timer):
self.timers.remove(timer)
self.del_setting("timer-%s" % timer.id)
def add_timer(self, event_name, delay, next_due=None, id=None, persist=True,
**kwargs):
id = id or uuid.uuid4().hex
timer = Timer.Timer(id, self, self._events, event_name, delay,
next_due, **kwargs)
if id:
timer.id = id
elif persist:
self.timer_setting(timer)
self.timers.append(timer)
def next_timer(self):
next = None
for timer in self.timers:
time_left = timer.time_left()
if next == None or time_left < next:
next = time_left
if next == None:
return None
if next < 0:
return 0
return next
def call_timers(self):
for timer in self.timers[:]:
if timer.due():
timer.call()
if timer.done():
self.timer_setting_remove(timer)
def next_send(self): def next_send(self):
next = None next = None
for server in self.servers.values(): for server in self.servers.values():
@ -110,7 +74,7 @@ class Bot(object):
def get_poll_timeout(self): def get_poll_timeout(self):
timeouts = [] timeouts = []
timeouts.append(self.next_timer()) timeouts.append(self.timers.next())
timeouts.append(self.next_send()) timeouts.append(self.next_send())
timeouts.append(self.next_ping()) timeouts.append(self.next_ping())
timeouts.append(self.next_read_timeout()) timeouts.append(self.next_read_timeout())
@ -154,7 +118,8 @@ class Bot(object):
while self.running: while self.running:
self.lock.acquire() self.lock.acquire()
events = self.poll.poll(self.get_poll_timeout()) events = self.poll.poll(self.get_poll_timeout())
self.call_timers() self.timers.call()
for fd, event in events: for fd, event in events:
if fd in self.servers: if fd in self.servers:
server = self.servers[fd] server = self.servers[fd]
@ -185,7 +150,7 @@ class Bot(object):
self.disconnect(server) self.disconnect(server)
reconnect_delay = self.config.get("reconnect-delay", 10) reconnect_delay = self.config.get("reconnect-delay", 10)
self.add_timer("reconnect", reconnect_delay, None, None, False, self.timers.add("reconnect", reconnect_delay,
server_id=server.id) server_id=server.id)
print("disconnected from %s, reconnecting in %d seconds" % ( print("disconnected from %s, reconnecting in %d seconds" % (

View file

@ -14,9 +14,9 @@ CAPABILITIES = {"multi-prefix", "chghost", "invite-notify", "account-tag",
"batch", "draft/labeled-response"} "batch", "draft/labeled-response"}
class LineHandler(object): class LineHandler(object):
def __init__(self, bot, events): def __init__(self, events, timers):
self.bot = bot
self.events = events self.events = events
self.timers = timers
events.on("raw.PING").hook(self.ping) events.on("raw.PING").hook(self.ping)
events.on("raw.001").hook(self.handle_001, default_event=True) events.on("raw.001").hook(self.handle_001, default_event=True)
@ -570,10 +570,9 @@ class LineHandler(object):
# we need a registered nickname for this channel # we need a registered nickname for this channel
def handle_477(self, event): def handle_477(self, event):
channel_name = Utils.irc_lower(event["server"], event["args"][1]) channel_name = Utils.irc_lower(event["server"], event["args"][1])
if channel_name in event["server"].attempted_join: if channel_name in event["server"]:
self.bot.add_timer("rejoin", 5, key = event["server"].attempted_join[channel_name]
channel_name=event["args"][1], self.timers.add("rejoin", 5, channel_name=channe_name, key=key,
key=event["server"].attempted_join[channel_name],
server_id=event["server"].id) server_id=event["server"].id)
# someone's been kicked from a channel # someone's been kicked from a channel

View file

@ -27,13 +27,16 @@ class BaseModule(object):
self.exports = exports self.exports = exports
class ModuleManager(object): class ModuleManager(object):
def __init__(self, bot, events, exports, directory): def __init__(self, events, exports, config, log, directory):
self.bot = bot
self.events = events self.events = events
self.exports = exports self.exports = exports
self.config = config
self.log = log
self.directory = directory self.directory = directory
self.modules = {} self.modules = {}
self.waiting_requirement = {} self.waiting_requirement = {}
def list_modules(self): def list_modules(self):
return sorted(glob.glob(os.path.join(self.directory, "*.py"))) return sorted(glob.glob(os.path.join(self.directory, "*.py")))
@ -47,7 +50,7 @@ class ModuleManager(object):
def _get_magic(self, obj, magic, default): def _get_magic(self, obj, magic, default):
return getattr(obj, magic) if hasattr(obj, magic) else default return getattr(obj, magic) if hasattr(obj, magic) else default
def _load_module(self, name): def _load_module(self, bot, name):
path = self._module_path(name) path = self._module_path(name)
with io.open(path, mode="r", encoding="utf8") as module_file: with io.open(path, mode="r", encoding="utf8") as module_file:
@ -61,8 +64,7 @@ class ModuleManager(object):
raise ModuleNotLoadedWarning("module ignored") raise ModuleNotLoadedWarning("module ignored")
elif line_split[0] == "#--require-config" and len( elif line_split[0] == "#--require-config" and len(
line_split) > 1: line_split) > 1:
if not line_split[1].lower() in self.bot.config or not self.bot.config[ if not self.config.get(line_split[1].lower(), None):
line_split[1].lower()]:
# nope, required config option not present. # nope, required config option not present.
raise ModuleNotLoadedWarning( raise ModuleNotLoadedWarning(
"required config not present") "required config not present")
@ -88,8 +90,7 @@ class ModuleManager(object):
context = str(uuid.uuid4()) context = str(uuid.uuid4())
context_events = self.events.new_context(context) context_events = self.events.new_context(context)
context_exports = self.exports.new_context(context) context_exports = self.exports.new_context(context)
module_object = module.Module(self.bot, context_events, module_object = module.Module(bot, context_events, context_exports)
context_exports)
if not hasattr(module_object, "_name"): if not hasattr(module_object, "_name"):
module_object._name = name.title() module_object._name = name.title()
@ -109,29 +110,29 @@ class ModuleManager(object):
"attempted to be used twice") "attempted to be used twice")
return module_object return module_object
def load_module(self, name): def load_module(self, bot, name):
try: try:
module = self._load_module(name) module = self._load_module(bot, name)
except ModuleWarning as warning: except ModuleWarning as warning:
self.bot.log.error("Module '%s' not loaded", [name]) self.log.error("Module '%s' not loaded", [name])
raise raise
except Exception as e: except Exception as e:
self.bot.log.error("Failed to load module \"%s\": %s", self.log.error("Failed to load module \"%s\": %s",
[name, str(e)]) [name, str(e)])
raise raise
self.modules[module._import_name] = module self.modules[module._import_name] = module
if name in self.waiting_requirement: if name in self.waiting_requirement:
for requirement_name in self.waiting_requirement: for requirement_name in self.waiting_requirement:
self.load_module(requirement_name) self.load_module(bot, requirement_name)
self.bot.log.info("Module '%s' loaded", [name]) self.log.info("Module '%s' loaded", [name])
def load_modules(self, whitelist=[], blacklist=[]): def load_modules(self, bot, whitelist=[], blacklist=[]):
for path in self.list_modules(): for path in self.list_modules():
name = self._module_name(path) name = self._module_name(path)
if name in whitelist or (not whitelist and not name in blacklist): if name in whitelist or (not whitelist and not name in blacklist):
try: try:
self.load_module(name) self.load_module(bot, name)
except ModuleWarning: except ModuleWarning:
pass pass
@ -151,5 +152,5 @@ class ModuleManager(object):
references -= 1 # 'del module' removes one reference references -= 1 # 'del module' removes one reference
references -= 1 # one of the refs is from getrefcount references -= 1 # one of the refs is from getrefcount
self.bot.log.info("Module '%s' unloaded (%d reference%s)", self.log.info("Module '%s' unloaded (%d reference%s)",
[name, references, "" if references == 1 else "s"]) [name, references, "" if references == 1 else "s"])

View file

@ -1,39 +0,0 @@
import time, uuid
class Timer(object):
def __init__(self, id, bot, events, event_name, delay,
next_due=None, **kwargs):
self.id = id
self.bot = bot
self.events = events
self.event_name = event_name
self.delay = delay
if next_due:
self.next_due = next_due
else:
self.set_next_due()
self.kwargs = kwargs
self._done = False
self.call_count = 0
def set_next_due(self):
self.next_due = time.time()+self.delay
def due(self):
return self.time_left() <= 0
def time_left(self):
return self.next_due-time.time()
def call(self):
self._done = True
self.call_count +=1
self.events.on("timer").on(self.event_name).call(
timer=self, **self.kwargs)
def redo(self):
self._done = False
self.set_next_due()
def done(self):
return self._done

73
src/Timers.py Normal file
View file

@ -0,0 +1,73 @@
import time, uuid
class Timer(object):
def __init__(self, id, name, delay, next_due, kwargs):
self.id = id
self.name = name
self.delay = delay
if next_due:
self.next_due = next_due
else:
self.set_next_due()
self.kwargs = kwargs
self._done = False
def set_next_due(self):
self.next_due = time.time()+self.delay
def due(self):
return self.time_left() <= 0
def time_left(self):
return self.next_due-time.time()
def redo(self):
self._done = False
self.set_next_due()
def finish():
self._done = True
def done(self):
return self._done
class Timers(object):
def __init__(self, events, log):
self.events = events
self.log = log
self.timers = []
def setup(self, timers):
for name, timer in timers:
id = name.split("timer-", 1)[1]
self._add(timer["name"], timer["delay"], timer[
"next-due"], id, False, timer["kwargs"])
def _persist(self, timer):
self.set_setting("timer-%s" % timer.id, {
"name": timer.name, "delay": timer.delay,
"next-due": timer.next_due, "kwargs": timer.kwargs})
def _remove(self, timer):
self.timers.remove(timer)
self.del_setting("timer-%s" % timer.id)
def add(self, name, delay, next_due=None, **kwargs):
self._add(name, delay, next_due, None, False, kwargs)
def add_persistent(self, name, delay, next_due=None, **kwargs):
self._add(name, delay, next_due, None, True, kwargs)
def _add(self, name, delay, next_due, id, persist, kwargs):
id = id or uuid.uuid4().hex
timer = Timer(id, name, delay, next_due, kwargs)
if persist:
self._persist(timer)
self.timers.append(timer)
def next(self):
times = filter(None, [timer.time_left() for timer in self.timers])
if not times:
return None
return max(min(times), 0)
def call(self):
for timer in self.timers[:]:
if timer.due():
timer.finish()
self.events.on("timer.%s" % timer.name, timer=timer)
if timer.done():
self._remove(timer)

View file

@ -2,7 +2,7 @@
import argparse, os, sys, time import argparse, os, sys, time
from src import Config, Database, EventManager, Exports, IRCBot from src import Config, Database, EventManager, Exports, IRCBot
from src import IRCLineHandler, Logging, ModuleManager from src import IRCLineHandler, Logging, ModuleManager, Timers
def bool_input(s): def bool_input(s):
result = input("%s (Y/n): " % s) result = input("%s (Y/n): " % s)
@ -29,31 +29,23 @@ arg_parser.add_argument("--verbose", "-v", action="store_true")
args = arg_parser.parse_args() args = arg_parser.parse_args()
log = Logging.Log(args.log) log = Logging.Log(args.log)
config = Config.Config(args.config).load_config() config = Config.Config(args.config)
database = Database.Database(log, args.database) database = Database.Database(log, args.database)
events = events = EventManager.EventHook(log) events = events = EventManager.EventHook(log)
exports = exports = Exports.Exports() exports = exports = Exports.Exports()
timers = Timers.Timers(events, log)
bot = IRCBot.Bot() line_handler = IRCLineHandler.LineHandler(events, timers)
modules = modules = ModuleManager.ModuleManager(events, exports, config, log,
bot.modules = modules = ModuleManager.ModuleManager(bot, events, exports,
os.path.join(directory, "modules")) os.path.join(directory, "modules"))
bot.line_handler = IRCLineHandler.LineHandler(bot, events)
bot.log = log bot = IRCBot.Bot(args, config, database, events, exports, line_handler, log,
bot.config = config modules, timers)
bot.database = database
bot._events = events
bot._exports = exports
bot.args = args
bot._events.on("timer.reconnect").hook(bot.reconnect)
bot._events.on("boot.done").hook(bot.setup_timers)
whitelist = bot.get_setting("module-whitelist", []) whitelist = bot.get_setting("module-whitelist", [])
blacklist = bot.get_setting("module-blacklist", []) blacklist = bot.get_setting("module-blacklist", [])
bot.modules.load_modules(whitelist=whitelist, blacklist=blacklist) modules.load_modules(bot, whitelist=whitelist, blacklist=blacklist)
servers = [] servers = []
for server_id, alias in bot.database.servers.get_all(): for server_id, alias in bot.database.servers.get_all():
@ -62,6 +54,9 @@ for server_id, alias in bot.database.servers.get_all():
servers.append(server) servers.append(server)
if len(servers): if len(servers):
bot._events.on("boot.done").call() bot._events.on("boot.done").call()
bot.timers.setup(bot.find_settings_prefix("timer-"))
for server in servers: for server in servers:
if not bot.connect(server): if not bot.connect(server):
sys.stderr.write("failed to connect to '%s', exiting\r\n" % ( sys.stderr.write("failed to connect to '%s', exiting\r\n" % (