diff --git a/modules/git_webhooks/__init__.py b/modules/git_webhooks/__init__.py index fb0899c6..e75f4e39 100644 --- a/modules/git_webhooks/__init__.py +++ b/modules/git_webhooks/__init__.py @@ -34,20 +34,20 @@ class Module(ModuleManager.BaseModule): @utils.hook("api.post.github") def _api_github_webhook(self, event): return self._webhook("github", "GitHub", self._github, - event["data"], event["headers"]) + event["data"], event["headers"], event["params"]) @utils.hook("api.post.gitea") def _api_gitea_webhook(self, event): return self._webhook("gitea", "Gitea", self._gitea, - event["data"], event["headers"]) + event["data"], event["headers"], event["params"]) @utils.hook("api.post.gitlab") def _api_gitlab_webhook(self, event): return self._webhook("gitlab", "GitLab", self._gitlab, - event["data"], event["headers"]) + event["data"], event["headers"], event["params"]) def _webhook(self, webhook_type, webhook_name, handler, payload_str, - headers): + headers, params): payload = payload_str.decode("utf8") if headers["Content-Type"] == FORM_ENCODED: payload = urllib.parse.unquote(urllib.parse.parse_qs(payload)[ @@ -69,41 +69,41 @@ class Module(ModuleManager.BaseModule): branch = handler.branch(data, headers) current_event, event_action = handler.event(data, headers) - hooks = self.bot.database.channel_settings.find_by_setting( - "git-webhooks") + unfiltered_targets = [] + if "channels" in params: + channels = params["channels"].split(",") + for channel in params["channels"].split(","): + server, _, channel_name = channel.partition(":") + if server and channel_name: + server = self.bot.get_server_by_alias(server) + if server and channel_name in server.channels: + channel = server.channels.get(channel_name) + hooks = channel.get_setting("git-webhooks", {}) + + if hooks: + found_hook = self._find_hook( + full_name_lower, repo_username_lower, + organisation_lower, hooks) + + if found_hook: + unfiltered_targets.append([ + server, channel, found_hook]) + else: + unfiltered_targets = self._find_targets(full_name_lower, + repo_username_lower, organisation_lower) + + repo_hooked = bool(unfiltered_targets) targets = [] - repo_hooked = False + for server, channel, hook in unfiltered_targets: + events = [] + for hooked_event in hook["events"]: + events.append(handler.event_categories(hooked_event)) + events = list(itertools.chain(*events)) - for server_id, channel_name, hooked_repos in hooks: - hooked_repos_lower = {k.lower(): v for k, v in hooked_repos.items()} - found_hook = None - if full_name_lower and full_name_lower in hooked_repos_lower: - found_hook = hooked_repos_lower[full_name_lower] - elif repo_username_lower and repo_username_lower in hooked_repos_lower: - found_hook = hooked_repos_lower[repo_username_lower] - elif organisation_lower and organisation_lower in hooked_repos_lower: - found_hook = hooked_repos_lower[organisation_lower] - else: - continue - - repo_hooked = True - server = self.bot.get_server_by_id(server_id) - if server and channel_name in server.channels: - if (branch and - found_hook["branches"] and - not branch in found_hook["branches"]): - continue - - events = [] - for hooked_event in found_hook["events"]: - events.append(handler.event_categories(hooked_event)) - events = list(itertools.chain(*events)) - - channel = server.channels.get(channel_name) - if (current_event in events or - (event_action and event_action in events)): - targets.append([server, channel]) + if (current_event in events or + (event_action and event_action in events)): + targets.append([server, channel]) if not targets: if not repo_hooked: @@ -155,6 +155,34 @@ class Module(ModuleManager.BaseModule): return s + def _find_targets(self, full_name_lower, repo_username_lower, + organisation_lower): + hooks = self.bot.database.channel_settings.find_by_setting( + "git-webhooks") + targets = [] + for server_id, channel_name, hooked_repos in hooks: + found_hook = self._find_hook(full_name_lower, repo_username_lower, + organisation_lower, hooked_repos) + server = self.bot.get_server_by_id(server_id) + if found_hook and server and channel_name in server.channels: + channel = server.channels.get(channel_name) + targets.append([server, channel, found_hook]) + + return targets + + def _find_hook(self, full_name_lower, repo_username_lower, + organisation_lower, hooks): + hooked_repos_lower = {k.lower(): v for k, v in hooks.items()} + if full_name_lower and full_name_lower in hooked_repos_lower: + return hooked_repos_lower[full_name_lower] + elif (repo_username_lower and + repo_username_lower in hooked_repos_lower): + return hooked_repos_lower[repo_username_lower] + elif (organisation_lower and + organisation_lower in hooked_repos_lower): + return hooked_repos_lower[organisation_lower] + + @utils.hook("received.command.webhook", min_args=1, channel_only=True) def github_webhook(self, event): """