From 8b909fc4114016629a734bee6f5eaf3697c7c1e5 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 9 Mar 2021 11:47:06 +0200 Subject: [PATCH 01/43] Update github repo syntax. Fixes #11 --- github/commands.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/github/commands.py b/github/commands.py index a62f915..005cf3a 100644 --- a/github/commands.py +++ b/github/commands.py @@ -68,7 +68,7 @@ async def wrapper(self: 'Commands', evt: Event, **kwargs) -> None: return decorator -repo_syntax = r"([A-Za-z0-9-_]+)/([A-Za-z0-9-_]+)" +repo_syntax = r"([A-Za-z0-9-]+)/([A-Za-z0-9-_.]+)" class Commands: From e2d98f16fdcb98c8165fa45204f236a986329e1b Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 27 May 2021 15:08:00 +0300 Subject: [PATCH 02/43] Add missing space between label and 'to' --- base-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/base-config.yaml b/base-config.yaml index b6dadfc..60b1965 100644 --- a/base-config.yaml +++ b/base-config.yaml @@ -34,7 +34,7 @@ templates: label_aggregation: > {% if aggregation.added_labels %} added {{ util.join_human_list(aggregation.added_labels, mutate=fancy_label) }} - {% if not aggregation.removed_labels %}to{% endif %} + {% if not aggregation.removed_labels %} to{% endif %} {% endif %} {% if aggregation.removed_labels %} {% if aggregation.added_labels %} From f047459ed2b54e38de4a7cc60f528056043004ed Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 12 Jun 2021 13:31:08 +0300 Subject: [PATCH 03/43] Fix excluding unnecessary links from plaintext --- base-config.yaml | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/base-config.yaml b/base-config.yaml index 60b1965..5fdc8eb 100644 --- a/base-config.yaml +++ b/base-config.yaml @@ -72,12 +72,12 @@ macros: > {%- endmacro -%} {%- macro user_link(user) -%} - {{ (user.name or user.login)|e }} + {{ (user.name or user.login)|e }} {%- endmacro -%} {%- macro commit_user_link(user) -%} {% if user.username %} - {{ user.name|e }} + {{ user.name|e }} {% else %} {{ user.name|e }} {% endif %} @@ -88,7 +88,7 @@ macros: > {%- endmacro -%} {%- macro repo_link(repo, important=True) -%} - {{ repo.full_name|e }} + {{ repo.full_name|e }} {%- endmacro -%} {%- macro personal_link(user, self_text=None, possessive=False, self=sender) -%} @@ -161,7 +161,7 @@ messages: {{ fancy_labels(pull_request.labels) }} {% elif action == CLOSED %} {% if pull_request.merged_at %} - merged + merged {% else %} closed {% endif %} @@ -272,7 +272,7 @@ messages: {% endif %} {% macro page_link(page) %} - {{ page.title }} + {{ page.title }} (diff) {% endmacro %} From 3764904739019917e29adbfb6bf79376eed5cc28 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 18 Jun 2021 21:05:19 +0300 Subject: [PATCH 04/43] Remove SerializableAttrs type param --- github/api/types.py | 98 ++++++++++++++++++++++----------------------- 1 file changed, 49 insertions(+), 49 deletions(-) diff --git a/github/api/types.py b/github/api/types.py index 399fccf..b6c4f9f 100644 --- a/github/api/types.py +++ b/github/api/types.py @@ -39,7 +39,7 @@ def datetime_deserializer(data: JSON) -> HubDateTime: @dataclass -class User(SerializableAttrs['User']): +class User(SerializableAttrs): login: str id: int node_id: str @@ -64,7 +64,7 @@ class User(SerializableAttrs['User']): @dataclass -class Organization(SerializableAttrs['Organization']): +class Organization(SerializableAttrs): id: int node_id: str login: str @@ -81,14 +81,14 @@ class Organization(SerializableAttrs['Organization']): @dataclass -class GitUser(SerializableAttrs['GitUser']): +class GitUser(SerializableAttrs): name: str email: str username: Optional[str] = None @dataclass -class License(SerializableAttrs['License']): +class License(SerializableAttrs): key: str name: str spdx_id: str @@ -97,7 +97,7 @@ class License(SerializableAttrs['License']): @dataclass -class Repository(SerializableAttrs['Repository']): +class Repository(SerializableAttrs): id: int node_id: str name: str @@ -181,7 +181,7 @@ def meta(self) -> Dict[str, Any]: @dataclass -class Commit(SerializableAttrs['Commit']): +class Commit(SerializableAttrs): id: str tree_id: str distinct: bool @@ -196,7 +196,7 @@ class Commit(SerializableAttrs['Commit']): @dataclass -class PushEvent(SerializableAttrs['PushEvent']): +class PushEvent(SerializableAttrs): ref: str before: str after: str @@ -217,7 +217,7 @@ class PushEvent(SerializableAttrs['PushEvent']): @dataclass -class ReleaseAsset(SerializableAttrs['ReleaseAsset']): +class ReleaseAsset(SerializableAttrs): id: int node_id: int url: str @@ -234,7 +234,7 @@ class ReleaseAsset(SerializableAttrs['ReleaseAsset']): @dataclass -class Release(SerializableAttrs['Release']): +class Release(SerializableAttrs): id: int node_id: str tag_name: str @@ -268,7 +268,7 @@ class ReleaseAction(SerializableEnum): @dataclass -class ReleaseEvent(SerializableAttrs['ReleaseEvent']): +class ReleaseEvent(SerializableAttrs): action: ReleaseAction release: Release repository: Repository @@ -281,7 +281,7 @@ class StarAction(SerializableEnum): @dataclass -class StarEvent(SerializableAttrs['StarEvent']): +class StarEvent(SerializableAttrs): action: StarAction starred_at: HubDateTime repository: Repository @@ -293,21 +293,21 @@ class WatchAction(SerializableEnum): @dataclass -class WatchEvent(SerializableAttrs['StarEvent']): +class WatchEvent(SerializableAttrs): action: WatchAction repository: Repository sender: User @dataclass -class ForkEvent(SerializableAttrs['ForkEvent']): +class ForkEvent(SerializableAttrs): forkee: Repository repository: Repository sender: User @dataclass -class Label(SerializableAttrs['Label']): +class Label(SerializableAttrs): id: int node_id: str url: str @@ -322,7 +322,7 @@ class IssueState(SerializableEnum): @dataclass -class Milestone(SerializableAttrs['Milestone']): +class Milestone(SerializableAttrs): id: int node_id: str number: int @@ -344,7 +344,7 @@ class Milestone(SerializableAttrs['Milestone']): @dataclass -class IssuePullURLs(SerializableAttrs['IssuePullURLs']): +class IssuePullURLs(SerializableAttrs): diff_url: str html_url: str patch_url: str @@ -352,7 +352,7 @@ class IssuePullURLs(SerializableAttrs['IssuePullURLs']): @dataclass -class Issue(SerializableAttrs['Issue']): +class Issue(SerializableAttrs): id: int node_id: str number: int @@ -413,18 +413,18 @@ class IssueAction(SerializableEnum): @dataclass -class Change(SerializableAttrs['Change']): +class Change(SerializableAttrs): original: str = attr.ib(metadata={"json": "from"}) @dataclass -class IssueChanges(SerializableAttrs['IssueChanges']): +class IssueChanges(SerializableAttrs): body: Optional[Change] = None title: Optional[Change] = None @dataclass -class IssuesEvent(SerializableAttrs['IssuesEvent']): +class IssuesEvent(SerializableAttrs): action: IssueAction issue: Issue repository: Repository @@ -443,7 +443,7 @@ def meta(self) -> Dict[str, Any]: @dataclass -class IssueComment(SerializableAttrs['IssueComment']): +class IssueComment(SerializableAttrs): id: int node_id: int url: str @@ -467,7 +467,7 @@ class CommentAction(SerializableEnum): @dataclass -class IssueCommentEvent(SerializableAttrs['IssueCommentEvent']): +class IssueCommentEvent(SerializableAttrs): action: CommentAction issue: Issue comment: IssueComment @@ -484,14 +484,14 @@ def meta(self) -> Dict[str, Any]: @dataclass -class WebhookResponse(SerializableAttrs['WebhookResponse']): +class WebhookResponse(SerializableAttrs): code: Optional[int] status: str message: Optional[str] @dataclass -class WebhookConfig(SerializableAttrs['WebhookConfig']): +class WebhookConfig(SerializableAttrs): url: str content_type: Optional[str] = None secret: Optional[str] = None @@ -499,7 +499,7 @@ class WebhookConfig(SerializableAttrs['WebhookConfig']): @dataclass -class Webhook(SerializableAttrs['Webhook']): +class Webhook(SerializableAttrs): id: int type: str name: str @@ -515,14 +515,14 @@ class Webhook(SerializableAttrs['Webhook']): @dataclass -class PingEvent(SerializableAttrs['PingEvent']): +class PingEvent(SerializableAttrs): zen: str hook_id: int hook: Webhook @dataclass -class CreateEvent(SerializableAttrs['CreateEvent']): +class CreateEvent(SerializableAttrs): ref_type: str ref: str master_branch: str @@ -533,7 +533,7 @@ class CreateEvent(SerializableAttrs['CreateEvent']): @dataclass -class DeleteEvent(SerializableAttrs['DeleteEvent']): +class DeleteEvent(SerializableAttrs): ref_type: str ref: str pusher_type: str @@ -546,7 +546,7 @@ class MetaAction(SerializableEnum): @dataclass -class MetaEvent(SerializableAttrs['MetaEvent']): +class MetaEvent(SerializableAttrs): action: MetaAction hook: Webhook hook_id: int @@ -555,7 +555,7 @@ class MetaEvent(SerializableAttrs['MetaEvent']): @dataclass -class CommitComment(SerializableAttrs['CommitComment']): +class CommitComment(SerializableAttrs): id: int node_id: str user: User @@ -581,7 +581,7 @@ def meta(self) -> Dict[str, Any]: @dataclass -class CommitCommentEvent(SerializableAttrs['CommitCommentEvent']): +class CommitCommentEvent(SerializableAttrs): action: CommentAction comment: CommitComment repository: Repository @@ -596,7 +596,7 @@ def meta(self) -> Dict[str, Any]: @dataclass -class MilestoneChanges(SerializableAttrs['MilestoneChanges']): +class MilestoneChanges(SerializableAttrs): title: Optional[Change] = None description: Optional[Change] = None due_on: Optional[Change] = None @@ -611,7 +611,7 @@ class MilestoneAction(SerializableEnum): @dataclass -class MilestoneEvent(SerializableAttrs['MilestoneEvent']): +class MilestoneEvent(SerializableAttrs): action: MilestoneAction milestone: Milestone repository: Repository @@ -626,13 +626,13 @@ class LabelAction(SerializableEnum): @dataclass -class LabelChanges(SerializableAttrs['LabelChanges']): +class LabelChanges(SerializableAttrs): name: Optional[Change] = None color: Optional[Change] = None @dataclass -class LabelEvent(SerializableAttrs['LabelEvent']): +class LabelEvent(SerializableAttrs): action: LabelAction label: Label changes: LabelChanges @@ -646,7 +646,7 @@ class WikiPageAction(SerializableEnum): @dataclass -class WikiPageEvent(SerializableAttrs['WikiPageEvent']): +class WikiPageEvent(SerializableAttrs): action: WikiPageAction page_name: str title: str @@ -656,14 +656,14 @@ class WikiPageEvent(SerializableAttrs['WikiPageEvent']): @dataclass -class WikiEvent(SerializableAttrs['WikiEvent']): +class WikiEvent(SerializableAttrs): pages: List[WikiPageEvent] repository: Repository sender: User @dataclass -class PublicEvent(SerializableAttrs['PublicEvent']): +class PublicEvent(SerializableAttrs): repository: Repository sender: User @@ -674,7 +674,7 @@ class PullRequestState(SerializableEnum): @dataclass -class PullRequestRef(SerializableAttrs['PullRequestRef']): +class PullRequestRef(SerializableAttrs): label: str ref: str sha: str @@ -694,7 +694,7 @@ class TeamPermission(SerializableEnum): @dataclass -class Team(SerializableAttrs['Team']): +class Team(SerializableAttrs): id: int node_id: str name: str @@ -710,7 +710,7 @@ class Team(SerializableAttrs['Team']): @dataclass -class PartialPullRequest(SerializableAttrs['PartialPullRequest']): +class PartialPullRequest(SerializableAttrs): id: int node_id: str number: int @@ -756,7 +756,7 @@ def meta(self) -> Dict[str, Any]: @dataclass -class PullRequest(PartialPullRequest, SerializableAttrs['PullRequest']): +class PullRequest(PartialPullRequest, SerializableAttrs): merged_by: Optional[User] draft: bool @@ -794,7 +794,7 @@ class PullRequestAction(SerializableEnum): @dataclass -class PullRequestEvent(SerializableAttrs['PullRequestEvent']): +class PullRequestEvent(SerializableAttrs): action: PullRequestAction pull_request: PullRequest number: int @@ -820,7 +820,7 @@ class ReviewState(SerializableEnum): @dataclass -class Review(SerializableAttrs['Review']): +class Review(SerializableAttrs): id: int node_id: str user: User @@ -834,12 +834,12 @@ class Review(SerializableAttrs['Review']): @dataclass -class ReviewChanges(SerializableAttrs['ReviewChanges']): +class ReviewChanges(SerializableAttrs): body: Optional[Change] = None @dataclass -class PullRequestReviewEvent(SerializableAttrs['PullRequestReviewEvent']): +class PullRequestReviewEvent(SerializableAttrs): action: PullRequestReviewAction pull_request: PartialPullRequest review: Review @@ -855,7 +855,7 @@ class PullRequestReviewCommentAction(SerializableEnum): @dataclass -class ReviewComment(SerializableAttrs['ReviewComment']): +class ReviewComment(SerializableAttrs): id: int node_id: str pull_request_review_id: int @@ -885,7 +885,7 @@ def meta(self) -> Dict[str, Any]: @dataclass -class PullRequestReviewCommentEvent(SerializableAttrs['PullRequestReviewCommentEvent']): +class PullRequestReviewCommentEvent(SerializableAttrs): action: PullRequestReviewCommentAction pull_request: PartialPullRequest comment: ReviewComment @@ -915,7 +915,7 @@ class RepositoryAction(SerializableEnum): @dataclass -class RepositoryEvent(SerializableAttrs['RepositoryEvent']): +class RepositoryEvent(SerializableAttrs): action: RepositoryAction repository: Repository sender: User From 6cb715e1ed6911f4da4e9358694a565e9fd74106 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 2 Aug 2021 11:38:38 +0300 Subject: [PATCH 05/43] Fix pull request merged event --- base-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/base-config.yaml b/base-config.yaml index 5fdc8eb..9ff58a8 100644 --- a/base-config.yaml +++ b/base-config.yaml @@ -161,7 +161,7 @@ messages: {{ fancy_labels(pull_request.labels) }} {% elif action == CLOSED %} {% if pull_request.merged_at %} - merged + merged {% else %} closed {% endif %} From 560518d0803218a59b5528b0d3aa388754505a0c Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 4 Aug 2021 11:07:43 +0300 Subject: [PATCH 06/43] Check that evt.label is set in issue aggregation --- github/webhook/aggregation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/github/webhook/aggregation.py b/github/webhook/aggregation.py index 20431f6..017a1d2 100644 --- a/github/webhook/aggregation.py +++ b/github/webhook/aggregation.py @@ -147,7 +147,7 @@ def aggregate(self, evt_type: EventType, evt: Event, delivery_id: str) -> bool: elif evt_type != self.event_type: return False elif self.event_type in (EventType.ISSUES, EventType.PULL_REQUEST): - if self.event.action == self.action_type.OPENED and evt.label.id in self._label_ids: + if self.event.action == self.action_type.OPENED and evt.label and evt.label.id in self._label_ids: # Label was already in original event, drop the message. pass elif self.event.action == self.action_type.X_LABEL_AGGREGATE: From d66b13a42a68eb660aa92ed9f437d1947f4044cf Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 4 Aug 2021 11:55:32 +0300 Subject: [PATCH 07/43] Add option to not ask for repo access --- README.md | 47 +++++++++++++++++++++++++++++++++------------- github/commands.py | 2 ++ 2 files changed, 36 insertions(+), 13 deletions(-) diff --git a/README.md b/README.md index fc0de16..305d9a9 100644 --- a/README.md +++ b/README.md @@ -10,15 +10,18 @@ See steps 4 and 5 below. ## Basic setup 1. **Set up the plugin like any other maubot plugin.** - You just have to upload the plugin, and then create an instance i.e. an association of a plugin and a client. - + You just have to upload the plugin, and then create an instance i.e. an + association of a plugin and a client. + You have to give this new instance an `instance_id` / a name, for example "my_github_bot" 2. **[Register a GitHub OAuth application](https://github.com/settings/developers) to get a `client_id` and `client_secret`.** Set the callback URL to `https://{maubot_host}/{plugin_base_path}/{instance_id}/auth` - Following our example, if your instance is hosted on `maubot.example.com` and you kept the default `plugin_base_path` i.e. `_matrix/maubot/plugin`, the Github's new oAuth App's form should go like this: + Following our example, if your instance is hosted on `maubot.example.com` + and you kept the default `plugin_base_path` i.e. `_matrix/maubot/plugin`, + the Github's new OAuth App's form should go like this: * Application name: My Github Bot * Homepage URL: https://maubot.example.com/ @@ -27,26 +30,44 @@ See steps 4 and 5 below. 3. **Set the `client_id` and `client_secret` in maubot.** - Copy these informations from your Github's oAuth App page and paste them in the instance page options. + Copy these informations from your Github's OAuth App page and paste them in + the instance page options. ``` client_id: client_secret: ``` - + And save the instance configuration. 4. **Use `!github login` to log in.** - After inviting your bot / client to a matrix channel, use the `!gh` or `!github` command to use the github instance. - + After inviting your bot / client to a matrix channel, use the `!gh` or + `!github` command to use the github instance. + Using `gh login` first is mandatory and needed once **per instance**. - - The bot will reply with a link leading to your personal Github's allowed oAuth apps page, where you shall grant the necessary rights to the bot oAuth app. + + The bot will reply with a link leading to your personal Github's allowed + OAuth apps page, where you shall grant the necessary rights to the bot + OAuth app. + + By default, the bot will request access to all public repos and to add + webhooks. You can control the permissions it wants with some flags: + + * `--no-repo` makes it not ask for repo access at all. + Only `!github webhook add` will work, other commands like `!github create` + will not. + * `--no-hook` makes it not ask for webhook access. + `!github webhook add` will not work. + * `--private` makes it ask for private repo access. Necessary if you want to + use the bot to manage private repos. 5. **Use `!github webhook add /` to add webhooks.** - This will let you see in the current channel all the commits, comments, issues, stars, forks, pull requests, and so on, for that given repository. - - You must have admin rights on the repositories you want to track, as adding webhooks to a repository requires manager access rights to a project. + This will let you see in the current channel all the commits, comments, + issues, stars, forks, pull requests, and so on, for that given repository. + + You must have admin rights on the repositories you want to track, as adding + webhooks to a repository requires manager access rights to a project. - Once you create a webhook and track a repository, it will be tracked **only in the room from which you are in**. + Once you create a webhook and track a repository, it will be tracked + **only in the room from which you are in**. diff --git a/github/commands.py b/github/commands.py index 005cf3a..47fdb17 100644 --- a/github/commands.py +++ b/github/commands.py @@ -106,6 +106,8 @@ async def login(self, evt: MessageEvent, flags: str, client: Optional[GitHubClie redirect_url = (self.bot.webapp_url / "auth").with_query({"user_id": evt.sender}) flags = flags.lower() scopes = ["user:user", "public_repo", "admin:repo_hook"] + if "--no-repo" in flags: + scopes.remove("public_repo") if "--no-hook" in flags: scopes.remove("admin:repo_hook") if "--private" in flags: From b085d8e3cc9208b20e4e82e8272a0e2e2f70c3b6 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 4 Aug 2021 12:15:40 +0300 Subject: [PATCH 08/43] Handle GraphQL errors from GitHub properly --- github/api/__init__.py | 2 +- github/api/client.py | 21 ++++++++++++++++++--- github/commands.py | 11 +++++++++-- 3 files changed, 28 insertions(+), 6 deletions(-) diff --git a/github/api/__init__.py b/github/api/__init__.py index 303797f..91a434d 100644 --- a/github/api/__init__.py +++ b/github/api/__init__.py @@ -1,2 +1,2 @@ -from .client import GitHubClient, GitHubError +from .client import GitHubClient, GitHubError, GraphQLError from .webhook import GitHubWebhookReceiver diff --git a/github/api/client.py b/github/api/client.py index 8925c38..7de2b56 100644 --- a/github/api/client.py +++ b/github/api/client.py @@ -35,6 +35,13 @@ def __init__(self, message: str, documentation_url: str, status: int, **kwargs) self.kwargs = kwargs +class GraphQLError(Exception): + def __init__(self, type: str, message: str, **kwargs) -> None: + super().__init__(message) + self.type = type + self.kwargs = kwargs + + class GitHubClient: base_url: URL = URL("https://api.github.com") api_url: URL = base_url / "graphql" @@ -94,10 +101,18 @@ async def call(self, query_type: str, query: str, args: str, variables: Optional full_query += f" ({args})" full_query += " {%s}" % query resp = await self.call_raw(full_query, variables) - print(resp) + try: + error = resp["errors"][0] + raise GraphQLError(**error) + except (KeyError, IndexError): + try: + data = resp["data"] + except KeyError: + raise GraphQLError(type="UNKNOWN_ERROR", + message="Unknown error: GitHub didn't return any data") if path: - return recursive_get(resp["data"], path) - return resp["data"] + return recursive_get(data, path) + return data @property def headers(self) -> Dict[str, str]: diff --git a/github/commands.py b/github/commands.py index 47fdb17..c4f00f6 100644 --- a/github/commands.py +++ b/github/commands.py @@ -20,7 +20,7 @@ from maubot.handlers import command, event from mautrix.types import EventType, Event, ReactionEvent, RelationType -from .api import GitHubClient, GitHubError +from .api import GitHubClient, GitHubError, GraphQLError if TYPE_CHECKING: from .bot import GitHubBot @@ -36,7 +36,14 @@ async def wrapper(self: 'Commands', evt: Event, **kwargs) -> None: return elif client and not client.token: client = None - return await fn(self, evt, **kwargs, client=client) + try: + return await fn(self, evt, **kwargs, client=client) + except GraphQLError as e: + if e.type == "INSUFFICIENT_SCOPES": + await evt.reply("Your login doesn't have sufficient access to do that. " + "Try adding more permissions with `!github login`.") + else: + await evt.reply(str(e)) return wrapper From 660853d7947e675beca04316dc3037c8a11c19e2 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 4 Aug 2021 12:17:20 +0300 Subject: [PATCH 09/43] Remove remaining debug prints The other one was removed in b085d8e3cc9208b20e4e82e8272a0e2e2f70c3b6 Closes #13 --- github/api/client.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/github/api/client.py b/github/api/client.py index 7de2b56..af4b85b 100644 --- a/github/api/client.py +++ b/github/api/client.py @@ -157,8 +157,6 @@ async def create_webhook(self, owner: str, repo: str, url: URL, *, active: bool "events": events or ["push"], "active": active, } - print(self.base_url / "repos" / owner / repo / "hooks") - print(payload) resp = await self.http.post(self.base_url / "repos" / owner / repo / "hooks", data=json.dumps(payload), headers=self.headers) data = await resp.json() From 7a467edba3d876648851b52a40240926ad09174d Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 4 Aug 2021 12:21:35 +0300 Subject: [PATCH 10/43] Add support for reacting to pull requests. Fixes #10 --- github/api/types.py | 7 +++++++ github/commands.py | 2 ++ 2 files changed, 9 insertions(+) diff --git a/github/api/types.py b/github/api/types.py index b6c4f9f..87723a9 100644 --- a/github/api/types.py +++ b/github/api/types.py @@ -806,6 +806,13 @@ class PullRequestEvent(SerializableAttrs): milestone: Optional[Milestone] = None requested_reviewer: Optional[User] = None + def meta(self) -> Dict[str, Any]: + return { + "pull_request": self.pull_request.meta(), + "repository": self.repository.meta(), + "action": str(self.action), + } + class PullRequestReviewAction(SerializableEnum): SUBMITTED = "submitted" diff --git a/github/commands.py b/github/commands.py index c4f00f6..f9e2d1d 100644 --- a/github/commands.py +++ b/github/commands.py @@ -168,6 +168,8 @@ async def handle_reaction(self, evt: ReactionEvent, client: GitHubClient, return if webhook_meta["event_type"] == "issues" and webhook_meta["action"] == "opened": subject_id = webhook_meta["issue"]["node_id"] + elif webhook_meta["event_type"] == "pull_request" and webhook_meta["action"] == "opened": + subject_id = webhook_meta["pull_request"]["node_id"] elif webhook_meta["event_type"] == "issue_comment" and webhook_meta["action"] == "created": subject_id = webhook_meta["comment"]["node_id"] else: From 2d1099df881a1577cad9add449caf3549e50e6f9 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 4 Aug 2021 12:22:11 +0300 Subject: [PATCH 11/43] Don't respond with GraphQL errors for reactions --- github/commands.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/github/commands.py b/github/commands.py index f9e2d1d..1dfdbd6 100644 --- a/github/commands.py +++ b/github/commands.py @@ -39,11 +39,12 @@ async def wrapper(self: 'Commands', evt: Event, **kwargs) -> None: try: return await fn(self, evt, **kwargs, client=client) except GraphQLError as e: - if e.type == "INSUFFICIENT_SCOPES": - await evt.reply("Your login doesn't have sufficient access to do that. " - "Try adding more permissions with `!github login`.") - else: - await evt.reply(str(e)) + if error: + if e.type == "INSUFFICIENT_SCOPES": + await evt.reply("Your login doesn't have sufficient access to do that. " + "Try adding more permissions with `!github login`.") + else: + await evt.reply(str(e)) return wrapper From bec7ea0513a28de7924aadd7ecb7621c7e20b376 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 4 Aug 2021 17:52:56 +0300 Subject: [PATCH 12/43] Make sure issue ID matches when aggregating events --- github/api/types.py | 12 ++++++++++++ github/webhook/aggregation.py | 6 +++++- 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/github/api/types.py b/github/api/types.py index 87723a9..b0f75ed 100644 --- a/github/api/types.py +++ b/github/api/types.py @@ -434,6 +434,10 @@ class IssuesEvent(SerializableAttrs): milestone: Optional[Milestone] = None changes: Optional[JSON] = None + @property + def issue_id(self) -> int: + return self.issue.id + def meta(self) -> Dict[str, Any]: return { "issue": self.issue.meta(), @@ -474,6 +478,10 @@ class IssueCommentEvent(SerializableAttrs): repository: Repository sender: User + @property + def issue_id(self) -> int: + return self.issue.id + def meta(self) -> Dict[str, Any]: return { "issue": self.issue.meta(), @@ -806,6 +814,10 @@ class PullRequestEvent(SerializableAttrs): milestone: Optional[Milestone] = None requested_reviewer: Optional[User] = None + @property + def issue_id(self) -> int: + return self.pull_request.id + def meta(self) -> Dict[str, Any]: return { "pull_request": self.pull_request.meta(), diff --git a/github/webhook/aggregation.py b/github/webhook/aggregation.py index 017a1d2..882a2d3 100644 --- a/github/webhook/aggregation.py +++ b/github/webhook/aggregation.py @@ -130,6 +130,7 @@ def aggregate(self, evt_type: EventType, evt: Event, delivery_id: str) -> bool: postpone = True if (evt_type == EventType.ISSUES and self.event_type == EventType.ISSUE_COMMENT and evt.action in (IssueAction.CLOSED, IssueAction.REOPENED) + and evt.issue_id == self.event.issue_id and self.event.sender.id == evt.sender.id): if evt.action == IssueAction.CLOSED: self.aggregation["closed"] = True @@ -137,6 +138,7 @@ def aggregate(self, evt_type: EventType, evt: Event, delivery_id: str) -> bool: self.aggregation["reopened"] = True elif (evt_type == EventType.ISSUE_COMMENT and self.event_type == EventType.ISSUES and evt.action == CommentAction.CREATED and self.event.sender.id == evt.sender.id + and evt.issue_id == self.event.issue_id and self.event.action in (IssueAction.CLOSED, IssueAction.REOPENED)): self.event_type = evt_type self.event = evt @@ -147,7 +149,9 @@ def aggregate(self, evt_type: EventType, evt: Event, delivery_id: str) -> bool: elif evt_type != self.event_type: return False elif self.event_type in (EventType.ISSUES, EventType.PULL_REQUEST): - if self.event.action == self.action_type.OPENED and evt.label and evt.label.id in self._label_ids: + if (self.event.action == self.action_type.OPENED + and evt.issue_id == self.event.issue_id + and evt.label and evt.label.id in self._label_ids): # Label was already in original event, drop the message. pass elif self.event.action == self.action_type.X_LABEL_AGGREGATE: From 21848cf181e84f94cf002fe51018ec0adc5ce31b Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 4 Aug 2021 17:55:30 +0300 Subject: [PATCH 13/43] Move issue ID check up --- github/webhook/aggregation.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/github/webhook/aggregation.py b/github/webhook/aggregation.py index 882a2d3..cfc3401 100644 --- a/github/webhook/aggregation.py +++ b/github/webhook/aggregation.py @@ -138,7 +138,7 @@ def aggregate(self, evt_type: EventType, evt: Event, delivery_id: str) -> bool: self.aggregation["reopened"] = True elif (evt_type == EventType.ISSUE_COMMENT and self.event_type == EventType.ISSUES and evt.action == CommentAction.CREATED and self.event.sender.id == evt.sender.id - and evt.issue_id == self.event.issue_id + and evt.issue_id == self.event.issue_id and self.event.action in (IssueAction.CLOSED, IssueAction.REOPENED)): self.event_type = evt_type self.event = evt @@ -148,9 +148,9 @@ def aggregate(self, evt_type: EventType, evt: Event, delivery_id: str) -> bool: self.aggregation["reopened"] = True elif evt_type != self.event_type: return False - elif self.event_type in (EventType.ISSUES, EventType.PULL_REQUEST): + elif (self.event_type in (EventType.ISSUES, EventType.PULL_REQUEST) + and evt.issue_id == self.event.issue_id): if (self.event.action == self.action_type.OPENED - and evt.issue_id == self.event.issue_id and evt.label and evt.label.id in self._label_ids): # Label was already in original event, drop the message. pass From fa11763123361eeb4e91f0d1e3cce848fbb2f381 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 4 Aug 2021 18:04:09 +0300 Subject: [PATCH 14/43] Increase webhook aggregation timeout --- github/webhook/aggregation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/github/webhook/aggregation.py b/github/webhook/aggregation.py index cfc3401..08e644b 100644 --- a/github/webhook/aggregation.py +++ b/github/webhook/aggregation.py @@ -64,7 +64,7 @@ def start_milestone_aggregation(self) -> None: (EventType.ISSUES, IssueAction.CLOSED): noop, } - timeout = 1 + timeout = 3 handler: 'WebhookHandler' webhook_info: WebhookInfo From a65513076224398e7525a174ca44e34c8c87924c Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 10 Aug 2021 13:21:58 +0300 Subject: [PATCH 15/43] Allow HTML in markdown from GitHub --- github/template/manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/github/template/manager.py b/github/template/manager.py index 638fd6a..ab33ed9 100644 --- a/github/template/manager.py +++ b/github/template/manager.py @@ -32,7 +32,7 @@ def __init__(self, config: Config, key: str) -> None: self._loader = ConfigTemplateLoader(config, key) self._env = JinjaEnvironment(loader=self._loader, lstrip_blocks=True, trim_blocks=True, extensions=["jinja2.ext.do"]) - self._env.filters["markdown"] = markdown.render + self._env.filters["markdown"] = lambda message: markdown.render(message, allow_html=True) def __getitem__(self, item: str) -> Template: return self._env.get_template(item) From 656cf2ec9740b71a52fc955a798a6056ad0e7f78 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 10 Aug 2021 13:23:25 +0300 Subject: [PATCH 16/43] Escape HTML in commit messages --- base-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/base-config.yaml b/base-config.yaml index 9ff58a8..5523ba2 100644 --- a/base-config.yaml +++ b/base-config.yaml @@ -244,7 +244,7 @@ messages: {% if commit.distinct %}
  • {{ commit.id[:8] }} - {{ util.cut_message(commit.message) }} + {{ util.cut_message(commit.message)|e }} {% if commit.author.username != sender.login %} by {{ commit_user_link(commit.author) }} {% endif %} From e68194ac2a2a6d2bec71061b67b535c2a5b17ddd Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 28 Nov 2021 15:35:33 +0200 Subject: [PATCH 17/43] Update CI artifact expiry --- .gitlab-ci.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index c649b91..45ef06b 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -15,6 +15,7 @@ build: artifacts: paths: - "*.mbp" + expire_in: 365 days build tags: stage: build @@ -25,3 +26,4 @@ build tags: artifacts: paths: - "*.mbp" + expire_in: never From d8ad9032f5af58f17457cb655cd5bae8559e0097 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 28 Nov 2021 18:42:08 +0200 Subject: [PATCH 18/43] Add support for resetting user tokens --- base-config.yaml | 4 ++++ github/api/client.py | 33 ++++++++++++++++++++++++++------- github/bot.py | 31 ++++++++++++++++++++++++++++++- github/client_manager.py | 22 ++++++++++++++++++++-- github/config.py | 3 ++- 5 files changed, 82 insertions(+), 11 deletions(-) diff --git a/base-config.yaml b/base-config.yaml index 5523ba2..60dffea 100644 --- a/base-config.yaml +++ b/base-config.yaml @@ -9,6 +9,10 @@ webhook_key: generate # This is useful if you're self-hosting, but don't want to set up the GitHub OAuth stuff. global_webhook_secret: null +# Set this to true to regenerate all user tokens using GitHub's reset token API. +# It will switch back to false automatically. +reset_tokens: false + command_options: # Prefix for all the bot commands. Does not include the ! prefix: diff --git a/github/api/client.py b/github/api/client.py index af4b85b..657e750 100644 --- a/github/api/client.py +++ b/github/api/client.py @@ -1,5 +1,5 @@ # github - A maubot plugin to act as a GitHub client and webhook receiver. -# Copyright (C) 2020 Tulir Asokan +# Copyright (C) 2021 Tulir Asokan # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -21,8 +21,8 @@ from aiohttp import ClientSession from yarl import URL -from github.api.types import Webhook from ..util import recursive_get +from .types import Webhook OptStrList = Optional[List[str]] @@ -121,6 +121,13 @@ def headers(self) -> Dict[str, str]: "Accept": "application/json", } + @property + def rest_v3_headers(self) -> Dict[str, str]: + return { + **self.headers, + "Accept": "application/vnd.github.v3+json", + } + async def call_raw(self, query: str, variables: Optional[Dict] = None) -> dict: resp = await self.http.post(self.api_url, json={ @@ -130,14 +137,24 @@ async def call_raw(self, query: str, variables: Optional[Dict] = None) -> dict: headers=self.headers) return await resp.json() + async def reset_token(self) -> Optional[str]: + url = ((self.base_url / "applications" / self.client_id / "token") + .with_user(self.client_id).with_password(self.client_secret)) + resp = await self.http.patch(url, json={"access_token": self.token}) + resp_data = await resp.json() + if resp.status == 404: + return None + self.token = resp_data["token"] + return self.token + async def list_webhooks(self, owner: str, repo: str) -> List[Webhook]: resp = await self.http.get(self.base_url / "repos" / owner / repo / "hooks", - headers=self.headers) + headers=self.rest_v3_headers) return [Webhook.deserialize(info) for info in await resp.json()] async def get_webhook(self, owner: str, repo: str, hook_id: int) -> Webhook: resp = await self.http.get(self.base_url / "repos" / owner / repo / "hooks" / str(hook_id), - headers=self.headers) + headers=self.rest_v3_headers) data = await resp.json() if resp.status != 200: raise GitHubError(status=resp.status, **data) @@ -158,7 +175,7 @@ async def create_webhook(self, owner: str, repo: str, url: URL, *, active: bool "active": active, } resp = await self.http.post(self.base_url / "repos" / owner / repo / "hooks", - data=json.dumps(payload), headers=self.headers) + data=json.dumps(payload), headers=self.rest_v3_headers) data = await resp.json() if resp.status != 201: raise GitHubError(status=resp.status, **data) @@ -192,7 +209,7 @@ async def edit_webhook(self, owner: str, repo: str, hook_id: int, *, url: Option payload["config"] = config resp = await self.http.patch( self.base_url / "repos" / owner / repo / "hooks" / str(hook_id), - data=json.dumps(payload), headers=self.headers) + data=json.dumps(payload), headers=self.rest_v3_headers) data = await resp.json() if resp.status != 200: raise GitHubError(status=resp.status, **data) @@ -200,7 +217,9 @@ async def edit_webhook(self, owner: str, repo: str, hook_id: int, *, url: Option async def delete_webhook(self, owner: str, repo: str, hook_id: int) -> None: resp = await self.http.delete( - self.base_url / "repos" / owner / repo / "hooks" / str(hook_id), headers=self.headers) + self.base_url / "repos" / owner / repo / "hooks" / str(hook_id), + headers=self.rest_v3_headers, + ) if resp.status != 204: data = await resp.json() raise GitHubError(status=resp.status, **data) diff --git a/github/bot.py b/github/bot.py index 131ac8d..f2b5107 100644 --- a/github/bot.py +++ b/github/bot.py @@ -1,5 +1,5 @@ # github - A maubot plugin to act as a GitHub client and webhook receiver. -# Copyright (C) 2020 Tulir Asokan +# Copyright (C) 2021 Tulir Asokan # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -14,6 +14,7 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . from typing import Type +import asyncio from sqlalchemy import MetaData @@ -56,6 +57,31 @@ async def start(self) -> None: self.register_handler_class(self.clients) self.register_handler_class(self.commands) + async def reset_tokens(self) -> None: + try: + await self._reset_tokens() + except Exception: + self.log.exception("Error resetting user tokens") + + async def _reset_tokens(self) -> None: + self.config["reset_tokens"] = False + self.config.save() + self.log.info("Resetting all user tokens") + for user_id, client in self.clients.get_all().items(): + self.log.debug(f"Resetting {user_id}'s token...") + try: + new_token = await client.reset_token() + except Exception: + self.log.warning(f"Failed to reset {user_id}'s token", exc_info=True) + else: + if new_token is None: + self.log.debug(f"{user_id}'s token was not valid, removing from database") + self.clients.remove(user_id) + else: + self.log.debug(f"Successfully reset {user_id}'s token") + self.clients.put(user_id, new_token) + self.log.debug("Finished resetting all user tokens") + def on_external_config_update(self) -> None: self.config.load_and_update() self.clients.client_id = self.config["client_id"] @@ -63,6 +89,9 @@ def on_external_config_update(self) -> None: self.webhook_handler.reload_config() self.commands.reload_config() + if self.config["reset_tokens"]: + asyncio.create_task(self.reset_tokens()) + @classmethod def get_config_class(cls) -> Type[Config]: return Config diff --git a/github/client_manager.py b/github/client_manager.py index 5d92777..b8f840a 100644 --- a/github/client_manager.py +++ b/github/client_manager.py @@ -55,11 +55,19 @@ def _make(self, token: str) -> GitHubClient: client_secret=self.client_secret, token=token) - def _save(self, user_id: UserID, token: str) -> None: + def put(self, user_id: UserID, token: str) -> None: with self._db.begin() as conn: conn.execute(self._table.delete().where(self._table.c.user_id == user_id)) conn.execute(self._table.insert().values(user_id=user_id, token=token)) + def remove(self, user_id: UserID) -> None: + with self._db.begin() as conn: + self._clients.pop(user_id, None) + conn.execute(self._table.delete().where(self._table.c.user_id == user_id)) + + def get_all(self) -> Dict[UserID, GitHubClient]: + return self._clients + def get(self, user_id: UserID, create: bool = False) -> Optional[GitHubClient]: try: return self._clients[user_id] @@ -73,6 +81,16 @@ def get(self, user_id: UserID, create: bool = False) -> Optional[GitHubClient]: @web_handler.get("/auth") async def login_callback(self, request: web.Request) -> web.Response: # TODO fancy webpages here + try: + error_code = request.query["error"] + error_msg = request.query["error_description"] + error_uri = request.query.get("error_uri", "") + except KeyError: + pass + else: + return web.Response(status=400, text=f"Failed to log in: {error_code}\n\n" + f"{error_msg}\n\n" + f"More info at {error_uri}") try: user_id = UserID(request.query["user_id"]) code = request.query["code"] @@ -90,5 +108,5 @@ async def login_callback(self, request: web.Request) -> web.Response: return web.Response(status=401, text="Failed to finish login") resp = await client.query("viewer { login }") user = resp["viewer"]["login"] - self._save(user_id, client.token) + self.put(user_id, client.token) return web.Response(status=200, text=f"Logged in as {user}") diff --git a/github/config.py b/github/config.py index c35dcc7..edb3d40 100644 --- a/github/config.py +++ b/github/config.py @@ -1,5 +1,5 @@ # github - A maubot plugin to act as a GitHub client and webhook receiver. -# Copyright (C) 2020 Tulir Asokan +# Copyright (C) 2021 Tulir Asokan # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -29,6 +29,7 @@ def do_update(self, helper: ConfigUpdateHelper) -> None: if helper.source.get("webhook_key", "generate") == "generate" else helper.source["webhook_key"]) helper.copy("global_webhook_secret") + helper.copy("reset_tokens") helper.copy("command_options.prefix") helper.copy("message_options.msgtype") helper.copy("message_options.aggregation_timeout") From b2c4424ca74d3dbd55f7a51097948b65271e888e Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 28 Nov 2021 18:43:52 +0200 Subject: [PATCH 19/43] Update meta --- maubot.yaml | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/maubot.yaml b/maubot.yaml index 52df716..cdd6ac7 100644 --- a/maubot.yaml +++ b/maubot.yaml @@ -9,7 +9,4 @@ extra_files: - base-config.yaml database: true webapp: true -#dependencies: -#- foo -#soft_dependencies: -#- bar>=0.1 +config: true From a9f1313edfdb3dec1109fc5bb5d2b965c27f714e Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 28 Nov 2021 18:46:37 +0200 Subject: [PATCH 20/43] Fix deleting tokens while resetting them --- github/client_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/github/client_manager.py b/github/client_manager.py index b8f840a..52a0d38 100644 --- a/github/client_manager.py +++ b/github/client_manager.py @@ -66,7 +66,7 @@ def remove(self, user_id: UserID) -> None: conn.execute(self._table.delete().where(self._table.c.user_id == user_id)) def get_all(self) -> Dict[UserID, GitHubClient]: - return self._clients + return self._clients.copy() def get(self, user_id: UserID, create: bool = False) -> Optional[GitHubClient]: try: From 197c4da916a232286ef0c386e6f11f4237d64e96 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 28 Nov 2021 18:51:38 +0200 Subject: [PATCH 21/43] Add command to invalidate and remove GitHub token --- github/api/client.py | 12 +++++++++--- github/commands.py | 7 +++++++ 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/github/api/client.py b/github/api/client.py index 657e750..27ecbb7 100644 --- a/github/api/client.py +++ b/github/api/client.py @@ -137,16 +137,22 @@ async def call_raw(self, query: str, variables: Optional[Dict] = None) -> dict: headers=self.headers) return await resp.json() - async def reset_token(self) -> Optional[str]: - url = ((self.base_url / "applications" / self.client_id / "token") + @property + def _token_url(self) -> URL: + return ((self.base_url / "applications" / self.client_id / "token") .with_user(self.client_id).with_password(self.client_secret)) - resp = await self.http.patch(url, json={"access_token": self.token}) + + async def reset_token(self) -> Optional[str]: + resp = await self.http.patch(self._token_url, json={"access_token": self.token}) resp_data = await resp.json() if resp.status == 404: return None self.token = resp_data["token"] return self.token + async def delete_token(self) -> None: + await self.http.delete(self._token_url, json={"access_token": self.token}) + async def list_webhooks(self, owner: str, repo: str) -> List[Webhook]: resp = await self.http.get(self.base_url / "repos" / owner / repo / "hooks", headers=self.rest_v3_headers) diff --git a/github/commands.py b/github/commands.py index 1dfdbd6..5a10edb 100644 --- a/github/commands.py +++ b/github/commands.py @@ -129,6 +129,13 @@ async def login(self, evt: MessageEvent, flags: str, client: Optional[GitHubClie else: await evt.reply(f"[Click here to log in]({login_url})") + @github.subcommand("logout", help="Delete the stored GitHub access token.") + @authenticated + async def logout(self, evt: MessageEvent, client: GitHubClient) -> None: + await client.delete_token() + self.bot.clients.remove(evt.sender) + await evt.reply("Successfully logged out") + @event.on(EventType.ROOM_MESSAGE) @authenticated(error=False) @with_webhook_meta(RelationType.REFERENCE) From 286c64ff601046310dc7f168b26e1a5555a79d2d Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 28 Nov 2021 18:51:44 +0200 Subject: [PATCH 22/43] Bump version to v0.1.1 --- maubot.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/maubot.yaml b/maubot.yaml index cdd6ac7..6ef4da5 100644 --- a/maubot.yaml +++ b/maubot.yaml @@ -1,6 +1,6 @@ maubot: 0.1.0 id: xyz.maubot.github -version: 0.1.0 +version: 0.1.1 license: AGPL-3.0-or-later modules: - github From 23480a120c8ea6868e56fa4de1b748116668add9 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 28 Nov 2021 19:41:11 +0200 Subject: [PATCH 23/43] Fix pull request change request events not being parsed correctly --- base-config.yaml | 3 +-- github/api/types.py | 12 +++++++++++- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/base-config.yaml b/base-config.yaml index 60dffea..94ec9b1 100644 --- a/base-config.yaml +++ b/base-config.yaml @@ -196,8 +196,7 @@ messages: pull_request_review: > {{ templates.repo_sender_prefix }} {% if action == SUBMITTED %} - {{ review.state }} - {% if review.state == ReviewState.COMMENTED %} on {% endif %} + {{ review.state.action_str }} {{ templates.pr_link }} {%- if review.body %} :
    {{ review.body|markdown }}
    diff --git a/github/api/types.py b/github/api/types.py index b0f75ed..3b9771d 100644 --- a/github/api/types.py +++ b/github/api/types.py @@ -835,7 +835,17 @@ class PullRequestReviewAction(SerializableEnum): class ReviewState(SerializableEnum): COMMENTED = "commented" APPROVED = "approved" - REJECTED = "rejected" + REJECTED = "rejected" # TODO: is this an actual state? + CHANGES_REQUESTED = "changes_requested" + + @property + def action_str(self) -> str: + if self == ReviewState.CHANGES_REQUESTED: + return "requested changes on" + elif self == ReviewState.COMMENTED: + return "commented on" + else: + return self.value @dataclass From 96ca1f69296da3b1444e3abf03ad332eb574c50c Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Mon, 31 Jan 2022 16:12:33 +0200 Subject: [PATCH 24/43] Respond 202 on unknown event types --- github/api/webhook.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/github/api/webhook.py b/github/api/webhook.py index d8d0105..fb7d8c4 100644 --- a/github/api/webhook.py +++ b/github/api/webhook.py @@ -103,7 +103,7 @@ async def _handle(self, request: web.Request, webhook_info: 'WebhookInfo') -> we except KeyError as e: return web.Response(status=400, text=f"Missing {e.args[0]} header") except ValueError: - return web.Response(status=500, text="Unsupported event type") + return web.Response(status=202, text="Unsupported event type") text = await request.text() text_binary = text.encode("utf-8") secret = webhook_info.secret.encode("utf-8") @@ -119,7 +119,7 @@ async def _handle(self, request: web.Request, webhook_info: 'WebhookInfo') -> we try: type_class = EVENT_CLASSES[event_type] except KeyError: - return web.Response(status=500, text="Unsupported event type") + return web.Response(status=500, text="Content class not found") try: event = type_class.deserialize(data) except SerializerError: From 18425903dbda3fe4f00db157e7639a00671f2511 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 4 Mar 2022 16:18:45 +0200 Subject: [PATCH 25/43] Remove space removal (moved to HTML parser) --- github/webhook/aggregation.py | 4 ++-- github/webhook/handler.py | 11 +++++------ 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/github/webhook/aggregation.py b/github/webhook/aggregation.py index 08e644b..573afaf 100644 --- a/github/webhook/aggregation.py +++ b/github/webhook/aggregation.py @@ -17,8 +17,8 @@ import asyncio from .manager import WebhookInfo -from github.api.types import (Event, EventType, Action, IssueAction, PullRequestAction, - CommentAction, ACTION_CLASSES) +from ..api.types import (Event, EventType, Action, IssueAction, PullRequestAction, CommentAction, + ACTION_CLASSES) if TYPE_CHECKING: from .handler import WebhookHandler diff --git a/github/webhook/handler.py b/github/webhook/handler.py index f44ea37..2a37c4b 100644 --- a/github/webhook/handler.py +++ b/github/webhook/handler.py @@ -33,7 +33,7 @@ from .aggregation import PendingAggregation if TYPE_CHECKING: - from github.bot import GitHubBot + from ..bot import GitHubBot spaces = re.compile(" +") space = " " @@ -118,12 +118,11 @@ def abort() -> None: "aggregation": aggregation, } args["templates"] = self.templates.proxy(args) - content = TextMessageEventContent(msgtype=self.msgtype, format=Format.HTML, - formatted_body=tpl.render(**args)) - if not content.formatted_body or aborted: + html = tpl.render(**args) + if not html or aborted: return - content.formatted_body = spaces.sub(space, content.formatted_body.strip()) - content.body = parse_html(content.formatted_body) + content = TextMessageEventContent(msgtype=self.msgtype, format=Format.HTML, + formatted_body=html, body=await parse_html(html.strip())) content["xyz.maubot.github.webhook"] = { "delivery_ids": list(delivery_ids), "event_type": str(evt_type), From f05fdbde476e8f834e8882d2bf94009fb075ed9a Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 5 Apr 2022 18:43:19 +0300 Subject: [PATCH 26/43] Disable webhook message link previews in Beeper clients --- github/webhook/handler.py | 1 + 1 file changed, 1 insertion(+) diff --git a/github/webhook/handler.py b/github/webhook/handler.py index 2a37c4b..5d23ffe 100644 --- a/github/webhook/handler.py +++ b/github/webhook/handler.py @@ -128,4 +128,5 @@ def abort() -> None: "event_type": str(evt_type), **(evt.meta() if hasattr(evt, "meta") else {}), } + content["com.beeper.linkpreviews"] = [] await self.bot.client.send_message(room_id, content) From b99f277e39bf7b313be13d5f67d727f98ab648a8 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 5 Apr 2022 18:43:31 +0300 Subject: [PATCH 27/43] Bump version to v0.1.2 --- maubot.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/maubot.yaml b/maubot.yaml index 6ef4da5..27a4d03 100644 --- a/maubot.yaml +++ b/maubot.yaml @@ -1,6 +1,6 @@ -maubot: 0.1.0 +maubot: 0.3.0 id: xyz.maubot.github -version: 0.1.1 +version: 0.1.2 license: AGPL-3.0-or-later modules: - github From af018940cb1fe6102806587200ef32ffbf5cf47b Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 23 Apr 2022 11:12:44 +0300 Subject: [PATCH 28/43] Add custom message for PR synchronize event --- base-config.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/base-config.yaml b/base-config.yaml index 94ec9b1..f913185 100644 --- a/base-config.yaml +++ b/base-config.yaml @@ -170,6 +170,7 @@ messages: closed {% endif %} {% elif action == SYNCHRONIZE and pull_request.head.repo.id == pull_request.base.repo.id %} {% do abort() %} + {% elif action == SYNCHRONIZE %}pushed something to {% elif action == LABELED %} {% do abort() %} {% elif action == UNLABELED %} {% do abort() %} {% elif action == X_LABEL_AGGREGATE %} {{ templates.label_aggregation }} From b3c2fa393e4ceec40f30f57e52054fb9a634b5b7 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 26 May 2022 18:37:15 +0300 Subject: [PATCH 29/43] Reply with error if webhook creation fails --- github/api/client.py | 1 + github/commands.py | 15 ++++++++++++--- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/github/api/client.py b/github/api/client.py index 27ecbb7..9d5dcae 100644 --- a/github/api/client.py +++ b/github/api/client.py @@ -33,6 +33,7 @@ def __init__(self, message: str, documentation_url: str, status: int, **kwargs) self.documentation_url = documentation_url self.status = status self.kwargs = kwargs + self.message = message class GraphQLError(Exception): diff --git a/github/commands.py b/github/commands.py index 5a10edb..7b91e57 100644 --- a/github/commands.py +++ b/github/commands.py @@ -258,9 +258,18 @@ async def webhook_create(self, evt: MessageEvent, repo: Tuple[str, str], client: # TODO webhook may be deleted on github side return webhook = self.bot.webhook_manager.create(repo_name, evt.sender, evt.room_id) - await client.create_webhook(*repo, url=self.bot.webapp_url / "webhook" / str(webhook.id), - secret=webhook.secret, content_type="json", events=["*"]) - await evt.reply(f"Successfully created webhook for {repo_name}") + try: + await client.create_webhook( + *repo, url=self.bot.webapp_url / "webhook" / str(webhook.id), + secret=webhook.secret, + content_type="json", + events=["*"], + ) + except GitHubError as e: + await evt.reply(f"Failed to create webhook: {e.message}") + self.bot.webhook_manager.delete(webhook.id) + else: + await evt.reply(f"Successfully created webhook for {repo_name}") @webhook.subcommand("remove", aliases=["delete", "rm", "del"]) @command.argument("repo", required=True, matches=repo_syntax, label="owner/repo") From 070df74151c3ff8ef908c0455f3f579f0365cb54 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Fri, 27 May 2022 10:46:19 +0300 Subject: [PATCH 30/43] Add state_reason field for issues --- github/api/types.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/github/api/types.py b/github/api/types.py index 3b9771d..2b740cc 100644 --- a/github/api/types.py +++ b/github/api/types.py @@ -321,6 +321,11 @@ class IssueState(SerializableEnum): CLOSED = "closed" +class IssueStateReason(SerializableEnum): + COMPLETED = "completed" + NOT_PLANNED = "not_planned" + + @dataclass class Milestone(SerializableAttrs): id: int @@ -363,6 +368,7 @@ class Issue(SerializableAttrs): author_association: str labels: List[Label] state: IssueState + state_reason: Optional[IssueStateReason] locked: bool milestone: Optional[Milestone] From 30f2552682ad0b8ae69a3f6b4f4d6a100983fecb Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 19 Jun 2022 14:24:42 +0300 Subject: [PATCH 31/43] Move CI script to main maubot repo --- .gitlab-ci.yml | 32 +++----------------------------- 1 file changed, 3 insertions(+), 29 deletions(-) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 45ef06b..7c690ef 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -1,29 +1,3 @@ -image: dock.mau.dev/maubot/maubot - -stages: -- build - -variables: - PYTHONPATH: /opt/maubot - -build: - stage: build - except: - - tags - script: - - python3 -m maubot.cli build -o xyz.maubot.$CI_PROJECT_NAME-$CI_COMMIT_REF_NAME-$CI_COMMIT_SHORT_SHA.mbp - artifacts: - paths: - - "*.mbp" - expire_in: 365 days - -build tags: - stage: build - only: - - tags - script: - - python3 -m maubot.cli build -o xyz.maubot.$CI_PROJECT_NAME-$CI_COMMIT_TAG.mbp - artifacts: - paths: - - "*.mbp" - expire_in: never +include: +- project: 'maubot/maubot' + file: '/.gitlab-ci-plugin.yml' From b2f2b12a2c2100cec1bcb7b0aae5a7ccc7bfde15 Mon Sep 17 00:00:00 2001 From: Sumner Evans Date: Sat, 20 Aug 2022 09:36:02 -0600 Subject: [PATCH 32/43] Add support for build status emojis Signed-off-by: Sumner Evans --- github/api/types.py | 143 +++++++++++++++++++++++++++++++--- github/bot.py | 3 + github/db.py | 61 +++++++++++++++ github/webhook/aggregation.py | 11 ++- github/webhook/handler.py | 89 +++++++++++++++++---- 5 files changed, 279 insertions(+), 28 deletions(-) create mode 100644 github/db.py diff --git a/github/api/types.py b/github/api/types.py index 2b740cc..60bf41a 100644 --- a/github/api/types.py +++ b/github/api/types.py @@ -215,6 +215,12 @@ class PushEvent(SerializableAttrs): size: int = None distinct_size: int = None + @property + def message_id(self) -> str: + if not self.head_commit: + return "" + return f"push-{self.repository.id}-{self.head_commit.id}" + @dataclass class ReleaseAsset(SerializableAttrs): @@ -960,6 +966,89 @@ class RepositoryEvent(SerializableAttrs): changes: Optional[JSON] = None +class WorkflowJobAction(SerializableEnum): + QUEUED = "queued" + IN_PROGRESS = "in_progress" + COMPLETED = "completed" + + +class WorkflowConclusion(SerializableEnum): + SUCCESS = "success" + FAILURE = "failure" + NEUTRAL = "neutral" + CANCELLED = "cancelled" + TIMED_OUT = "timed_out" + ACTION_REQUIRED = "action_required" + STALE = "stale" + + +@dataclass +class WorkflowJob(SerializableAttrs): + id: int + run_id: int + run_url: str + name: str + head_sha: str + conclusion: Optional[WorkflowConclusion] = None + + @property + def meta(self) -> JSON: + info = { + "id": self.id, + "run_id": self.run_id, + "name": self.name, + "url": self.run_url, + } + if self.conclusion: + info["conclusion"] = self.conclusion.name + return info + + +_build_status_circles: Dict[WorkflowJobAction, Union[Dict[WorkflowConclusion, str], str]] = { + WorkflowJobAction.QUEUED: "🟡", + WorkflowJobAction.IN_PROGRESS: "🔵", + WorkflowJobAction.COMPLETED: { + WorkflowConclusion.SUCCESS: "🟢", + WorkflowConclusion.FAILURE: "🔴", + WorkflowConclusion.NEUTRAL: "⚪", + WorkflowConclusion.CANCELLED: "⚫️", + WorkflowConclusion.TIMED_OUT: "⏱️", + WorkflowConclusion.ACTION_REQUIRED: "⚠️", + WorkflowConclusion.STALE: "⚪", + }, +} + + +@dataclass +class WorkflowJobEvent(SerializableAttrs): + action: WorkflowJobAction + workflow_job: WorkflowJob + repository: Repository + sender: User + + organization: Optional[Organization] = None + + @property + def push_id(self) -> str: + return f"push-{self.repository.id}-{self.workflow_job.head_sha}" + + @property + def reaction_id(self) -> str: + return f"job-{self.repository.id}-{self.workflow_job.head_sha}-{self.workflow_job.id}" + + @property + def color_circle(self) -> str: + circle_def = _build_status_circles[self.action] + if isinstance(circle_def, str): + return circle_def + else: + return circle_def[self.workflow_job.conclusion] + + @property + def meta(self) -> JSON: + return {"build": self.workflow_job.meta} + + class EventType(SerializableEnum): ISSUES = "issues" ISSUE_COMMENT = "issue_comment" @@ -981,16 +1070,48 @@ class EventType(SerializableEnum): PULL_REQUEST_REVIEW = "pull_request_review" PULL_REQUEST_REVIEW_COMMENT = "pull_request_review_comment" REPOSITORY = "repository" - - -Event = Union[IssuesEvent, IssueCommentEvent, PushEvent, ReleaseEvent, StarEvent, WatchEvent, - PingEvent, ForkEvent, CreateEvent, MetaEvent, CommitCommentEvent, MilestoneEvent, - LabelEvent, WikiEvent, PublicEvent, PullRequestEvent, PullRequestReviewEvent, - PullRequestReviewCommentEvent, RepositoryEvent, DeleteEvent] - -Action = Union[IssueAction, StarAction, CommentAction, WikiPageAction, MetaAction, ReleaseAction, - PullRequestAction, PullRequestReviewAction, PullRequestReviewCommentAction, - MilestoneAction, LabelAction, RepositoryAction] + WORKFLOW_JOB = "workflow_job" + + +Event = Union[ + IssuesEvent, + IssueCommentEvent, + PushEvent, + ReleaseEvent, + StarEvent, + WatchEvent, + PingEvent, + ForkEvent, + CreateEvent, + MetaEvent, + CommitCommentEvent, + MilestoneEvent, + LabelEvent, + WikiEvent, + PublicEvent, + PullRequestEvent, + PullRequestReviewEvent, + PullRequestReviewCommentEvent, + RepositoryEvent, + DeleteEvent, + WorkflowJobEvent, +] + +Action = Union[ + IssueAction, + StarAction, + CommentAction, + WikiPageAction, + MetaAction, + ReleaseAction, + PullRequestAction, + PullRequestReviewAction, + PullRequestReviewCommentAction, + MilestoneAction, + LabelAction, + RepositoryAction, + WorkflowJobAction, +] EVENT_CLASSES = { EventType.ISSUES: IssuesEvent, @@ -1013,6 +1134,7 @@ class EventType(SerializableEnum): EventType.PULL_REQUEST_REVIEW: PullRequestReviewEvent, EventType.PULL_REQUEST_REVIEW_COMMENT: PullRequestReviewCommentEvent, EventType.REPOSITORY: RepositoryEvent, + EventType.WORKFLOW_JOB: WorkflowJobEvent, } @@ -1036,6 +1158,7 @@ def expand_enum(enum: Type[SerializableEnum]) -> Dict[str, SerializableEnum]: EventType.MILESTONE: MilestoneAction, EventType.LABEL: LabelAction, EventType.REPOSITORY: RepositoryAction, + EventType.WORKFLOW_JOB: WorkflowJobAction, } OTHER_ENUMS = { diff --git a/github/bot.py b/github/bot.py index f2b5107..f5c0a18 100644 --- a/github/bot.py +++ b/github/bot.py @@ -20,6 +20,7 @@ from maubot import Plugin +from .db import Database from .webhook import WebhookManager, WebhookHandler from .client_manager import ClientManager from .api import GitHubWebhookReceiver @@ -28,6 +29,7 @@ class GitHubBot(Plugin): + db: Database webhook_receiver: GitHubWebhookReceiver webhook_manager: WebhookManager webhook_handler: WebhookHandler @@ -40,6 +42,7 @@ async def start(self) -> None: metadata = MetaData() + self.db = Database(self.database) self.clients = ClientManager(self.config["client_id"], self.config["client_secret"], self.http, self.database, metadata) self.webhook_manager = WebhookManager(self.config["webhook_key"], diff --git a/github/db.py b/github/db.py new file mode 100644 index 0000000..8753c96 --- /dev/null +++ b/github/db.py @@ -0,0 +1,61 @@ +# github - A maubot plugin to act as a GitHub client and webhook receiver. +# Copyright (C) 2022 Sumner Evans +# Copyright (C) 2022 Tulir Asokan +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . +from typing import List, NamedTuple, Optional +import logging as log + +from sqlalchemy import Column, String, Text, ForeignKeyConstraint, or_, ForeignKey +from sqlalchemy.orm import sessionmaker, relationship, Session +from sqlalchemy.orm.exc import NoResultFound, MultipleResultsFound +from sqlalchemy.engine.base import Engine +from sqlalchemy.ext.declarative import declarative_base + +from mautrix.types import UserID, EventID, RoomID + +Base = declarative_base() + + +class MatrixMessage(Base): + __tablename__ = "matrix_message" + + message_id: str = Column(String(255), primary_key=True) + room_id: RoomID = Column(String(255), primary_key=True) + event_id: EventID = Column(String(255), nullable=False) + + +class Database: + db: Engine + + def __init__(self, db: Engine) -> None: + self.db = db + Base.metadata.create_all(db) + self.Session = sessionmaker(bind=self.db) + + def get_event(self, message_id: str, room_id: RoomID) -> Optional[EventID]: + s: Session = self.Session() + event = s.query(MatrixMessage).get((message_id, room_id)) + return event.event_id if event else None + + def put_event( + self, message_id: str, room_id: RoomID, event_id: EventID, merge: bool = False + ) -> None: + s: Session = self.Session() + evt = MatrixMessage(message_id=message_id, room_id=room_id, event_id=event_id) + if merge: + s.merge(evt) + else: + s.add(evt) + s.commit() diff --git a/github/webhook/aggregation.py b/github/webhook/aggregation.py index 573afaf..43ee42f 100644 --- a/github/webhook/aggregation.py +++ b/github/webhook/aggregation.py @@ -123,8 +123,15 @@ async def _sleep(self) -> None: pass async def _send(self) -> None: - await self.handler.send_message(self.event_type, self.event, self.webhook_info.room_id, - self.delivery_ids, aggregation=self.aggregation) + event_id = await self.handler.send_message( + self.event_type, + self.event, + self.webhook_info.room_id, + self.delivery_ids, + aggregation=self.aggregation, + ) + if self.event_type == EventType.PUSH and event_id: + self.handler.bot.db.put_event(self.event.message_id, self.webhook_info.room_id, event_id) def aggregate(self, evt_type: EventType, evt: Event, delivery_id: str) -> bool: postpone = True diff --git a/github/webhook/handler.py b/github/webhook/handler.py index 5d23ffe..a5f63af 100644 --- a/github/webhook/handler.py +++ b/github/webhook/handler.py @@ -13,7 +13,7 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import Dict, Set, Deque, Optional, Any, Callable, TYPE_CHECKING +from typing import Dict, Set, Deque, Optional, Any, TYPE_CHECKING from collections import deque, defaultdict from uuid import UUID import asyncio @@ -23,12 +23,31 @@ from jinja2 import TemplateNotFound import attr -from mautrix.types import TextMessageEventContent, Format, MessageType, RoomID +from mautrix.types import ( + EventID, + EventType as MautrixEventType, + Format, + MessageType, + ReactionEventContent, + RelatesTo, + RelationType, + RoomID, + TextMessageEventContent, +) from mautrix.util.formatter import parse_html from ..template import TemplateManager, TemplateUtil -from ..api.types import (Event, EventType, MetaAction, RepositoryAction, expand_enum, - ACTION_CLASSES, OTHER_ENUMS) +from ..api.types import ( + Event, + EventType, + MetaAction, + PushEvent, + RepositoryAction, + WorkflowJobEvent, + expand_enum, + ACTION_CLASSES, + OTHER_ENUMS, +) from .manager import WebhookInfo from .aggregation import PendingAggregation @@ -80,28 +99,62 @@ async def __call__(self, evt_type: EventType, evt: Event, delivery_id: str, info self.log.debug(f"Received repo delete hook for {info}") self.bot.webhook_manager.delete(info.id) elif evt_type == EventType.PUSH and (evt.size is None or evt.distinct_size is None): + assert isinstance(evt, PushEvent) evt.size = len(evt.commits) evt.distinct_size = len([commit for commit in evt.commits if commit.distinct]) + elif evt_type == EventType.WORKFLOW_JOB: + assert isinstance(evt, WorkflowJobEvent) + push_evt = self.bot.db.get_event(evt.push_id, info.room_id) + if not push_evt: + self.bot.log.debug(f"No message found to react to push {evt.push_id}") + return + reaction = ReactionEventContent( + RelatesTo(rel_type=RelationType.ANNOTATION, event_id=push_evt) + ) + try: + reaction.relates_to.key = f"{evt.color_circle} {evt.workflow_job.name}" + except KeyError: + return + reaction["xyz.maubot.gitlab.webhook"] = { + "event_type": evt_type.name, + **evt.meta, + } + + prev_reaction = self.bot.db.get_event(evt.reaction_id, info.room_id) + if prev_reaction: + await self.bot.client.redact(info.room_id, prev_reaction) + event_id = await self.bot.client.send_message_event( + info.room_id, MautrixEventType.REACTION, reaction + ) + self.bot.db.put_event( + evt.reaction_id, info.room_id, event_id, merge=prev_reaction is not None + ) if PendingAggregation.timeout < 0: # Aggregations are disabled - await self.send_message(evt_type, evt, info.room_id, {delivery_id}) + event_id = await self.send_message(evt_type, evt, info.room_id, {delivery_id}) + if evt_type == EventType.PUSH and event_id: + self.bot.db.put_event(evt.message_id, info.room_id, event_id) return for pending in self.pending_aggregations[info.id]: if pending.aggregate(evt_type, evt, delivery_id): return - asyncio.ensure_future(PendingAggregation(self, evt_type, evt, delivery_id, info) - .start()) - - async def send_message(self, evt_type: EventType, evt: Event, room_id: RoomID, - delivery_ids: Set[str], aggregation: Optional[Dict[str, Any]] = None - ) -> None: + asyncio.ensure_future(PendingAggregation(self, evt_type, evt, delivery_id, info).start()) + + async def send_message( + self, + evt_type: EventType, + evt: Event, + room_id: RoomID, + delivery_ids: Set[str], + aggregation: Optional[Dict[str, Any]] = None, + ) -> Optional[EventID]: try: tpl = self.messages[str(evt_type)] except TemplateNotFound: self.log.debug(f"Unhandled event of type {evt_type} -- {delivery_ids}") - return + return None aborted = False @@ -120,13 +173,17 @@ def abort() -> None: args["templates"] = self.templates.proxy(args) html = tpl.render(**args) if not html or aborted: - return - content = TextMessageEventContent(msgtype=self.msgtype, format=Format.HTML, - formatted_body=html, body=await parse_html(html.strip())) + return None + content = TextMessageEventContent( + msgtype=self.msgtype, + format=Format.HTML, + formatted_body=html, + body=await parse_html(html.strip()), + ) content["xyz.maubot.github.webhook"] = { "delivery_ids": list(delivery_ids), "event_type": str(evt_type), **(evt.meta() if hasattr(evt, "meta") else {}), } content["com.beeper.linkpreviews"] = [] - await self.bot.client.send_message(room_id, content) + return await self.bot.client.send_message(room_id, content) From b17a052dd175ad7cc2a424d13ecdec6846d9667e Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sun, 12 Feb 2023 12:51:03 +0200 Subject: [PATCH 33/43] Use new wrapper for creating background tasks --- github/bot.py | 4 ++-- maubot.yaml | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/github/bot.py b/github/bot.py index f5c0a18..dd1f423 100644 --- a/github/bot.py +++ b/github/bot.py @@ -14,11 +14,11 @@ # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . from typing import Type -import asyncio from sqlalchemy import MetaData from maubot import Plugin +from mautrix.util import background_task from .db import Database from .webhook import WebhookManager, WebhookHandler @@ -93,7 +93,7 @@ def on_external_config_update(self) -> None: self.commands.reload_config() if self.config["reset_tokens"]: - asyncio.create_task(self.reset_tokens()) + background_task.create(self.reset_tokens()) @classmethod def get_config_class(cls) -> Type[Config]: diff --git a/maubot.yaml b/maubot.yaml index 27a4d03..30eb78a 100644 --- a/maubot.yaml +++ b/maubot.yaml @@ -1,4 +1,4 @@ -maubot: 0.3.0 +maubot: 0.4.1 id: xyz.maubot.github version: 0.1.2 license: AGPL-3.0-or-later From 79bfdd7f25e7ff7bb20126a22efa97cefd3fa84e Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 30 Nov 2024 19:51:10 +0200 Subject: [PATCH 34/43] Add hack to ignore specific job webhook --- github/webhook/handler.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/github/webhook/handler.py b/github/webhook/handler.py index a5f63af..b55cedd 100644 --- a/github/webhook/handler.py +++ b/github/webhook/handler.py @@ -103,6 +103,8 @@ async def __call__(self, evt_type: EventType, evt: Event, delivery_id: str, info evt.size = len(evt.commits) evt.distinct_size = len([commit for commit in evt.commits if commit.distinct]) elif evt_type == EventType.WORKFLOW_JOB: + if evt.workflow_job.name == "lock-stale": + return assert isinstance(evt, WorkflowJobEvent) push_evt = self.bot.db.get_event(evt.push_id, info.room_id) if not push_evt: From 8017ebd356801cf0e496d5ce7dcf2c91dad10fad Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Thu, 2 Jan 2025 23:32:30 +0200 Subject: [PATCH 35/43] Fix parsing error responses --- github/api/client.py | 12 ++++++------ github/commands.py | 2 +- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/github/api/client.py b/github/api/client.py index 9d5dcae..115a648 100644 --- a/github/api/client.py +++ b/github/api/client.py @@ -28,10 +28,10 @@ class GitHubError(Exception): - def __init__(self, message: str, documentation_url: str, status: int, **kwargs) -> None: + def __init__(self, message: str, documentation_url: str, status_code: int, **kwargs) -> None: super().__init__(message) self.documentation_url = documentation_url - self.status = status + self.status_code = status_code self.kwargs = kwargs self.message = message @@ -164,7 +164,7 @@ async def get_webhook(self, owner: str, repo: str, hook_id: int) -> Webhook: headers=self.rest_v3_headers) data = await resp.json() if resp.status != 200: - raise GitHubError(status=resp.status, **data) + raise GitHubError(status_code=resp.status, **data) return Webhook.deserialize(data) async def create_webhook(self, owner: str, repo: str, url: URL, *, active: bool = True, @@ -185,7 +185,7 @@ async def create_webhook(self, owner: str, repo: str, url: URL, *, active: bool data=json.dumps(payload), headers=self.rest_v3_headers) data = await resp.json() if resp.status != 201: - raise GitHubError(status=resp.status, **data) + raise GitHubError(status_code=resp.status, **data) return Webhook.deserialize(data) async def edit_webhook(self, owner: str, repo: str, hook_id: int, *, url: Optional[URL] = None, @@ -219,7 +219,7 @@ async def edit_webhook(self, owner: str, repo: str, hook_id: int, *, url: Option data=json.dumps(payload), headers=self.rest_v3_headers) data = await resp.json() if resp.status != 200: - raise GitHubError(status=resp.status, **data) + raise GitHubError(status_code=resp.status, **data) return Webhook.deserialize(data) async def delete_webhook(self, owner: str, repo: str, hook_id: int) -> None: @@ -229,4 +229,4 @@ async def delete_webhook(self, owner: str, repo: str, hook_id: int) -> None: ) if resp.status != 204: data = await resp.json() - raise GitHubError(status=resp.status, **data) + raise GitHubError(status_code=resp.status, **data) diff --git a/github/commands.py b/github/commands.py index 7b91e57..398f709 100644 --- a/github/commands.py +++ b/github/commands.py @@ -287,7 +287,7 @@ async def webhook_remove(self, evt: MessageEvent, repo: Tuple[str, str], try: await client.delete_webhook(*repo, hook_id=webhook_info.github_id) except GitHubError as e: - if e.status == 404: + if e.status_code == 404: await evt.reply("Webhook deleted successfully") return else: From 6b7060057e2bd16bbe175e8304e8368bab02a0f6 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Sat, 4 Jan 2025 00:18:20 +0200 Subject: [PATCH 36/43] Fix handling label webhooks before issue open webhook --- github/webhook/aggregation.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/github/webhook/aggregation.py b/github/webhook/aggregation.py index 43ee42f..59d3684 100644 --- a/github/webhook/aggregation.py +++ b/github/webhook/aggregation.py @@ -162,7 +162,12 @@ def aggregate(self, evt_type: EventType, evt: Event, delivery_id: str) -> bool: # Label was already in original event, drop the message. pass elif self.event.action == self.action_type.X_LABEL_AGGREGATE: - if evt.action == self.action_type.LABELED: + if evt.action == self.action_type.OPENED: + # Switch over to an issue create aggregation. + self.aggregation = {} + self.event = evt + self.start_open_label_dropping() + elif evt.action == self.action_type.LABELED: self.aggregation["added_labels"].append(evt.label) elif evt.action == self.action_type.UNLABELED: self.aggregation["removed_labels"].append(evt.label) From c6d768b2ec3657779e927646ad1f872d624f34df Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 28 Jan 2025 17:40:06 +0200 Subject: [PATCH 37/43] Use per-message profiles for webhooks --- github/avatar_manager.py | 51 +++++++++++++++++++++++++++++++++++++++ github/bot.py | 3 +++ github/db.py | 2 +- github/webhook/handler.py | 13 ++++++++++ 4 files changed, 68 insertions(+), 1 deletion(-) create mode 100644 github/avatar_manager.py diff --git a/github/avatar_manager.py b/github/avatar_manager.py new file mode 100644 index 0000000..8f2681a --- /dev/null +++ b/github/avatar_manager.py @@ -0,0 +1,51 @@ +from typing import TYPE_CHECKING +import asyncio + +from sqlalchemy import MetaData, Table, Column, Text +from sqlalchemy.engine.base import Engine + +from mautrix.types import ContentURI + +if TYPE_CHECKING: + from .bot import GitHubBot + + +class AvatarManager: + bot: 'GitHubBot' + _avatars: dict[str, ContentURI] + _table: Table + _db: Engine + _lock: asyncio.Lock + + def __init__(self, bot: 'GitHubBot', metadata: MetaData) -> None: + self.bot = bot + self._db = bot.database + self._table = Table("avatar", metadata, + Column("url", Text, primary_key=True), + Column("mxc", Text, nullable=False)) + self._lock = asyncio.Lock() + self._avatars = {} + + def load_db(self) -> None: + self._avatars = {url: ContentURI(mxc) + for url, mxc + in self._db.execute(self._table.select())} + + async def get_mxc(self, url: str) -> ContentURI: + try: + return self._avatars[url] + except KeyError: + pass + async with self.bot.http.get(url) as resp: + resp.raise_for_status() + data = await resp.read() + async with self._lock: + try: + return self._avatars[url] + except KeyError: + pass + mxc = await self.bot.client.upload_media(data) + self._avatars[url] = mxc + with self._db.begin() as conn: + conn.execute(self._table.insert().values(url=url, mxc=mxc)) + return mxc diff --git a/github/bot.py b/github/bot.py index dd1f423..1e67226 100644 --- a/github/bot.py +++ b/github/bot.py @@ -26,6 +26,7 @@ from .api import GitHubWebhookReceiver from .commands import Commands from .config import Config +from .avatar_manager import AvatarManager class GitHubBot(Plugin): @@ -33,6 +34,7 @@ class GitHubBot(Plugin): webhook_receiver: GitHubWebhookReceiver webhook_manager: WebhookManager webhook_handler: WebhookHandler + avatars: AvatarManager clients: ClientManager commands: Commands config: Config @@ -48,6 +50,7 @@ async def start(self) -> None: self.webhook_manager = WebhookManager(self.config["webhook_key"], self.database, metadata) self.webhook_handler = WebhookHandler(bot=self) + self.avatars = AvatarManager(bot=self, metadata=metadata) self.webhook_receiver = GitHubWebhookReceiver(handler=self.webhook_handler, secrets=self.webhook_manager, global_secret=self.config["global_webhook_secret"]) diff --git a/github/db.py b/github/db.py index 8753c96..ca87efd 100644 --- a/github/db.py +++ b/github/db.py @@ -23,7 +23,7 @@ from sqlalchemy.engine.base import Engine from sqlalchemy.ext.declarative import declarative_base -from mautrix.types import UserID, EventID, RoomID +from mautrix.types import UserID, EventID, RoomID, ContentURI Base = declarative_base() diff --git a/github/webhook/handler.py b/github/webhook/handler.py index b55cedd..2e79ae3 100644 --- a/github/webhook/handler.py +++ b/github/webhook/handler.py @@ -47,6 +47,7 @@ expand_enum, ACTION_CLASSES, OTHER_ENUMS, + User, ) from .manager import WebhookInfo from .aggregation import PendingAggregation @@ -182,6 +183,18 @@ def abort() -> None: formatted_body=html, body=await parse_html(html.strip()), ) + if hasattr(evt, "sender") and isinstance(evt.sender, User): + mxc = "" + if evt.sender.avatar_url: + try: + mxc = await self.bot.avatars.get_mxc(evt.sender.avatar_url) + except Exception: + self.log.warning("Failed to get avatar URL", exc_info=True) + content["com.beeper.per_message_profile"] = { + "id": str(evt.sender.id), + "displayname": evt.sender.login, + "avatar_url": mxc, + } content["xyz.maubot.github.webhook"] = { "delivery_ids": list(delivery_ids), "event_type": str(evt_type), From e09d876e29c91dc3fea5be534a9565f512143d8f Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Tue, 28 Jan 2025 17:40:38 +0200 Subject: [PATCH 38/43] Bump version to v0.2.0 --- maubot.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/maubot.yaml b/maubot.yaml index 30eb78a..81e81df 100644 --- a/maubot.yaml +++ b/maubot.yaml @@ -1,6 +1,6 @@ maubot: 0.4.1 id: xyz.maubot.github -version: 0.1.2 +version: 0.2.0 license: AGPL-3.0-or-later modules: - github From 38aacbe1efb25dc4242aba7728678357b0c526e5 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 7 May 2025 16:33:02 +0300 Subject: [PATCH 39/43] Update to asyncpg and allow migrating webhooks to different rooms --- github/api/webhook.py | 52 ++++---- github/avatar_manager.py | 19 ++- github/bot.py | 34 +++--- github/client_manager.py | 39 +++--- github/commands.py | 14 +-- github/db.py | 216 +++++++++++++++++++++++++++++----- github/migrations.py | 70 +++++++++++ github/util/__init__.py | 1 - github/util/sql_uuid.py | 45 ------- github/webhook/__init__.py | 2 +- github/webhook/aggregation.py | 2 +- github/webhook/handler.py | 21 ++-- github/webhook/manager.py | 148 ++++++----------------- maubot.yaml | 1 + 14 files changed, 372 insertions(+), 292 deletions(-) create mode 100644 github/migrations.py delete mode 100644 github/util/sql_uuid.py diff --git a/github/api/webhook.py b/github/api/webhook.py index fb7d8c4..948d0ce 100644 --- a/github/api/webhook.py +++ b/github/api/webhook.py @@ -13,65 +13,52 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import Optional, TYPE_CHECKING +from typing import Optional, Protocol, TYPE_CHECKING from uuid import UUID import hashlib import hmac import json from aiohttp import web +from attr import dataclass from mautrix.types import SerializerError, RoomID from maubot.handlers import web as web_handler from .types import EventType, Event, EVENT_CLASSES +from ..db import WebhookInfo if TYPE_CHECKING: - # Python 3.8+ only, so we do this in TYPE_CHECKING only - from typing import Protocol + from ..webhook import WebhookManager - class WebhookInfo(Protocol): - secret: str +class HandlerFunc(Protocol): + async def __call__(self, evt_type: EventType, evt: Event, delivery_id: str, + info: WebhookInfo) -> None: + pass - class HandlerFunc(Protocol): - async def __call__(self, evt_type: EventType, evt: Event, delivery_id: str, - info: WebhookInfo) -> None: - pass - - - class SecretDict(Protocol): - def __getitem__(self, item: str) -> WebhookInfo: - pass +@dataclass(frozen=True) +class GlobalWebhookInfo(WebhookInfo): + room_id: RoomID + secret: str -class GlobalWebhookInfo: id: UUID = UUID(int=0) user_id: str = "root" - github_id: Optional[int] = None + github_id: int | None = None repo: str = "unknown" - room_id: RoomID - secret: str - - def __init__(self, room_id: RoomID, secret: str) -> None: - self.room_id = room_id - self.secret = secret - - def __repr__(self) -> str: - return f"GlobalWebhookInfo(room_id={self.room_id!r})" - def __str__(self) -> str: return f"global webhook for {self.room_id!r}" class GitHubWebhookReceiver: handler: 'HandlerFunc' - secrets: 'SecretDict' + secrets: "WebhookManager" global_secret: Optional[str] - def __init__(self, handler: 'HandlerFunc', secrets: 'SecretDict', + def __init__(self, handler: 'HandlerFunc', secrets: "WebhookManager", global_secret: Optional[str]) -> None: self.handler = handler self.secrets = secrets @@ -85,13 +72,16 @@ async def handle_global(self, request: web.Request) -> web.Response: room_id = RoomID(request.query["room"]) except KeyError: return web.Response(status=400, text="room query param missing") - return await self._handle(request, GlobalWebhookInfo(room_id, self.global_secret)) + return await self._handle(request, GlobalWebhookInfo(room_id=room_id, secret=self.global_secret)) @web_handler.post("/webhook/{id}") async def handle(self, request: web.Request) -> web.Response: try: - webhook_info = self.secrets[request.match_info["id"]] - except (ValueError, KeyError): + id = UUID(request.match_info["id"]) + except ValueError: + return web.Response(status=404, text="Invalid webhook ID") + webhook_info = await self.secrets.get(id) + if webhook_info is None: return web.Response(status=404, text="Webhook not found") return await self._handle(request, webhook_info) diff --git a/github/avatar_manager.py b/github/avatar_manager.py index 8f2681a..fd35985 100644 --- a/github/avatar_manager.py +++ b/github/avatar_manager.py @@ -6,6 +6,8 @@ from mautrix.types import ContentURI +from .db import DBManager + if TYPE_CHECKING: from .bot import GitHubBot @@ -13,23 +15,19 @@ class AvatarManager: bot: 'GitHubBot' _avatars: dict[str, ContentURI] - _table: Table - _db: Engine + _db: DBManager _lock: asyncio.Lock - def __init__(self, bot: 'GitHubBot', metadata: MetaData) -> None: + def __init__(self, bot: 'GitHubBot') -> None: self.bot = bot - self._db = bot.database - self._table = Table("avatar", metadata, - Column("url", Text, primary_key=True), - Column("mxc", Text, nullable=False)) + self._db = bot.db self._lock = asyncio.Lock() self._avatars = {} - def load_db(self) -> None: + async def load_db(self) -> None: self._avatars = {url: ContentURI(mxc) for url, mxc - in self._db.execute(self._table.select())} + in await self._db.get_avatars()} async def get_mxc(self, url: str) -> ContentURI: try: @@ -46,6 +44,5 @@ async def get_mxc(self, url: str) -> ContentURI: pass mxc = await self.bot.client.upload_media(data) self._avatars[url] = mxc - with self._db.begin() as conn: - conn.execute(self._table.insert().values(url=url, mxc=mxc)) + await self._db.put_avatar(url, mxc) return mxc diff --git a/github/bot.py b/github/bot.py index 1e67226..2606034 100644 --- a/github/bot.py +++ b/github/bot.py @@ -15,12 +15,13 @@ # along with this program. If not, see . from typing import Type -from sqlalchemy import MetaData +from mautrix.util.async_db import UpgradeTable from maubot import Plugin from mautrix.util import background_task -from .db import Database +from .db import DBManager +from .migrations import upgrade_table from .webhook import WebhookManager, WebhookHandler from .client_manager import ClientManager from .api import GitHubWebhookReceiver @@ -30,7 +31,7 @@ class GitHubBot(Plugin): - db: Database + db: DBManager webhook_receiver: GitHubWebhookReceiver webhook_manager: WebhookManager webhook_handler: WebhookHandler @@ -42,22 +43,23 @@ class GitHubBot(Plugin): async def start(self) -> None: self.config.load_and_update() - metadata = MetaData() - - self.db = Database(self.database) + self.db = DBManager(self.database) + if await self.database.table_exists("needs_post_migration"): + self.log.info("Running database post-migration") + async with self.database.acquire() as conn, conn.transaction(): + await self.db.run_post_migration(conn, self.config["webhook_key"]) self.clients = ClientManager(self.config["client_id"], self.config["client_secret"], - self.http, self.database, metadata) - self.webhook_manager = WebhookManager(self.config["webhook_key"], - self.database, metadata) + self.http, self.db) + self.webhook_manager = WebhookManager(self.db) self.webhook_handler = WebhookHandler(bot=self) - self.avatars = AvatarManager(bot=self, metadata=metadata) + self.avatars = AvatarManager(bot=self) self.webhook_receiver = GitHubWebhookReceiver(handler=self.webhook_handler, secrets=self.webhook_manager, global_secret=self.config["global_webhook_secret"]) self.commands = Commands(bot=self) - metadata.create_all(self.database) - self.clients.load_db() + await self.clients.load_db() + await self.avatars.load_db() self.register_handler_class(self.webhook_receiver) self.register_handler_class(self.clients) @@ -82,10 +84,10 @@ async def _reset_tokens(self) -> None: else: if new_token is None: self.log.debug(f"{user_id}'s token was not valid, removing from database") - self.clients.remove(user_id) + await self.clients.remove(user_id) else: self.log.debug(f"Successfully reset {user_id}'s token") - self.clients.put(user_id, new_token) + await self.clients.put(user_id, new_token) self.log.debug("Finished resetting all user tokens") def on_external_config_update(self) -> None: @@ -101,3 +103,7 @@ def on_external_config_update(self) -> None: @classmethod def get_config_class(cls) -> Type[Config]: return Config + + @classmethod + def get_db_upgrade_table(cls) -> UpgradeTable: + return upgrade_table diff --git a/github/client_manager.py b/github/client_manager.py index 52a0d38..0af4835 100644 --- a/github/client_manager.py +++ b/github/client_manager.py @@ -13,41 +13,33 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import Dict, Optional - -from sqlalchemy import MetaData, Table, Column, String -from sqlalchemy.engine.base import Engine from aiohttp import web, ClientError, ClientSession from mautrix.types import UserID from maubot.handlers import web as web_handler from .api import GitHubClient +from .db import DBManager class ClientManager: client_id: str client_secret: str - _clients: Dict[UserID, GitHubClient] - _table: Table - _db: Engine + _clients: dict[UserID, GitHubClient] + _db: DBManager _http: ClientSession - def __init__(self, client_id: str, client_secret: str, http: ClientSession, - db: Engine, metadata: MetaData): + def __init__(self, client_id: str, client_secret: str, http: ClientSession, db: DBManager): self.client_id = client_id self.client_secret = client_secret self._db = db self._http = http - self._table = Table("client", metadata, - Column("user_id", String(255), primary_key=True), - Column("token", String(255), nullable=False)) self._clients = {} - def load_db(self) -> None: + async def load_db(self) -> None: self._clients = {user_id: self._make(token) for user_id, token - in self._db.execute(self._table.select())} + in await self._db.get_clients()} def _make(self, token: str) -> GitHubClient: return GitHubClient(http=self._http, @@ -55,20 +47,17 @@ def _make(self, token: str) -> GitHubClient: client_secret=self.client_secret, token=token) - def put(self, user_id: UserID, token: str) -> None: - with self._db.begin() as conn: - conn.execute(self._table.delete().where(self._table.c.user_id == user_id)) - conn.execute(self._table.insert().values(user_id=user_id, token=token)) + async def put(self, user_id: UserID, token: str) -> None: + await self._db.put_client(user_id, token) - def remove(self, user_id: UserID) -> None: - with self._db.begin() as conn: - self._clients.pop(user_id, None) - conn.execute(self._table.delete().where(self._table.c.user_id == user_id)) + async def remove(self, user_id: UserID) -> None: + self._clients.pop(user_id, None) + await self._db.delete_client(user_id) - def get_all(self) -> Dict[UserID, GitHubClient]: + def get_all(self) -> dict[UserID, GitHubClient]: return self._clients.copy() - def get(self, user_id: UserID, create: bool = False) -> Optional[GitHubClient]: + def get(self, user_id: UserID, create: bool = False) -> GitHubClient | None: try: return self._clients[user_id] except KeyError: @@ -108,5 +97,5 @@ async def login_callback(self, request: web.Request) -> web.Response: return web.Response(status=401, text="Failed to finish login") resp = await client.query("viewer { login }") user = resp["viewer"]["login"] - self.put(user_id, client.token) + await self.put(user_id, client.token) return web.Response(status=200, text=f"Logged in as {user}") diff --git a/github/commands.py b/github/commands.py index 398f709..6442a78 100644 --- a/github/commands.py +++ b/github/commands.py @@ -133,7 +133,7 @@ async def login(self, evt: MessageEvent, flags: str, client: Optional[GitHubClie @authenticated async def logout(self, evt: MessageEvent, client: GitHubClient) -> None: await client.delete_token() - self.bot.clients.remove(evt.sender) + await self.bot.clients.remove(evt.sender) await evt.reply("Successfully logged out") @event.on(EventType.ROOM_MESSAGE) @@ -240,7 +240,7 @@ async def webhook(self, evt: MessageEvent, repo: Tuple[str, str], client: GitHub @webhook.subcommand("list", aliases=["ls", "l"], help="List webhooks in this room.") async def webhook_list(self, evt: MessageEvent) -> None: - hooks = self.bot.webhook_manager.get_all_for_room(evt.room_id) + hooks = await self.bot.webhook_manager.get_all_for_room(evt.room_id) info = "\n".join(f"* `{hook.repo}` added by " f"[{hook.user_id}](https://matrix.to/#/{hook.user_id})" for hook in hooks) @@ -252,12 +252,12 @@ async def webhook_list(self, evt: MessageEvent) -> None: async def webhook_create(self, evt: MessageEvent, repo: Tuple[str, str], client: GitHubClient ) -> None: repo_name = f"{repo[0]}/{repo[1]}" - existing = self.bot.webhook_manager.find(repo_name, evt.room_id) + existing = await self.bot.webhook_manager.get_by_repo(repo_name, evt.room_id) if existing: await evt.reply("This room already has a webhook for that repo") # TODO webhook may be deleted on github side return - webhook = self.bot.webhook_manager.create(repo_name, evt.sender, evt.room_id) + webhook = await self.bot.webhook_manager.create(repo_name, evt.sender, evt.room_id) try: await client.create_webhook( *repo, url=self.bot.webapp_url / "webhook" / str(webhook.id), @@ -267,7 +267,7 @@ async def webhook_create(self, evt: MessageEvent, repo: Tuple[str, str], client: ) except GitHubError as e: await evt.reply(f"Failed to create webhook: {e.message}") - self.bot.webhook_manager.delete(webhook.id) + await self.bot.webhook_manager.delete(webhook.id) else: await evt.reply(f"Successfully created webhook for {repo_name}") @@ -277,11 +277,11 @@ async def webhook_create(self, evt: MessageEvent, repo: Tuple[str, str], client: async def webhook_remove(self, evt: MessageEvent, repo: Tuple[str, str], client: Optional[GitHubClient]) -> None: repo_name = f"{repo[0]}/{repo[1]}" - webhook_info = self.bot.webhook_manager.find(repo_name, evt.room_id) + webhook_info = await self.bot.webhook_manager.get_by_repo(repo_name, evt.room_id) if not webhook_info: await evt.reply("This room does not have a webhook for that repo") return - self.bot.webhook_manager.delete(webhook_info.id) + await self.bot.webhook_manager.delete(webhook_info.id) if webhook_info.github_id: if client: try: diff --git a/github/db.py b/github/db.py index ca87efd..2fed17b 100644 --- a/github/db.py +++ b/github/db.py @@ -1,6 +1,6 @@ # github - A maubot plugin to act as a GitHub client and webhook receiver. # Copyright (C) 2022 Sumner Evans -# Copyright (C) 2022 Tulir Asokan +# Copyright (C) 2025 Tulir Asokan # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by @@ -14,48 +14,200 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import List, NamedTuple, Optional -import logging as log +import uuid +import hmac +import hashlib +from typing import Optional -from sqlalchemy import Column, String, Text, ForeignKeyConstraint, or_, ForeignKey -from sqlalchemy.orm import sessionmaker, relationship, Session -from sqlalchemy.orm.exc import NoResultFound, MultipleResultsFound -from sqlalchemy.engine.base import Engine -from sqlalchemy.ext.declarative import declarative_base +from asyncpg import Record +from attr import dataclass from mautrix.types import UserID, EventID, RoomID, ContentURI +from mautrix.util.async_db import Database, Connection -Base = declarative_base() +@dataclass(frozen=True) +class Client: + user_id: UserID + token: str -class MatrixMessage(Base): - __tablename__ = "matrix_message" + @classmethod + def from_row(cls, row: Record | None) -> Optional["Client"]: + if not row: + return None + user_id = row["user_id"] + token = row["token"] + return cls( + user_id=user_id, + token=token, + ) - message_id: str = Column(String(255), primary_key=True) - room_id: RoomID = Column(String(255), primary_key=True) - event_id: EventID = Column(String(255), nullable=False) +@dataclass(frozen=True) +class Avatar: + url: str + mxc: ContentURI -class Database: - db: Engine + @classmethod + def from_row(cls, row: Record | None) -> Optional["Avatar"]: + if not row: + return None + url = row["url"] + mxc = row["mxc"] + return cls( + url=url, + mxc=mxc, + ) - def __init__(self, db: Engine) -> None: +@dataclass(frozen=True) +class WebhookInfo: + id: uuid.UUID + repo: str + user_id: UserID + room_id: RoomID + secret: str + github_id: int | None = None + + @classmethod + def from_row(cls, row: Record | None) -> Optional["WebhookInfo"]: + if not row: + return None + id = row["id"] + repo = row["repo"] + user_id = row["user_id"] + room_id = row["room_id"] + github_id = row["github_id"] + secret = row["secret"] + return cls( + id=uuid.UUID(id), + repo=repo, + user_id=user_id, + room_id=room_id, + github_id=github_id, + secret=secret, + ) + + def __str__(self) -> str: + return (f"webhook {self.id!s} (GH{self.github_id}) from {self.repo} to {self.room_id}" + f" added by {self.user_id}") + + +class DBManager: + db: Database + + def __init__(self, db: Database) -> None: self.db = db - Base.metadata.create_all(db) - self.Session = sessionmaker(bind=self.db) - def get_event(self, message_id: str, room_id: RoomID) -> Optional[EventID]: - s: Session = self.Session() - event = s.query(MatrixMessage).get((message_id, room_id)) - return event.event_id if event else None + async def get_event(self, message_id: str, room_id: RoomID) -> EventID | None: + return await self.db.fetchval( + "SELECT event_id FROM matrix_message WHERE message_id = $1 AND room_id = $2", + message_id, room_id, + ) - def put_event( - self, message_id: str, room_id: RoomID, event_id: EventID, merge: bool = False + async def put_event( + self, message_id: str, room_id: RoomID, event_id: EventID, ) -> None: - s: Session = self.Session() - evt = MatrixMessage(message_id=message_id, room_id=room_id, event_id=event_id) - if merge: - s.merge(evt) - else: - s.add(evt) - s.commit() + await self.db.execute( + """ + INSERT INTO matrix_message (message_id, room_id, event_id) VALUES ($1, $2, $3) + ON CONFLICT (message_id, room_id) DO UPDATE SET event_id = excluded.event_id + """, + message_id, room_id, event_id, + ) + + async def get_clients(self) -> list[Client]: + rows = await self.db.fetch("SELECT user_id, token FROM client") + return [Client.from_row(row) for row in rows] + + async def put_client(self, user_id: UserID, token: str) -> None: + await self.db.execute( + """ + INSERT INTO client (user_id, token) VALUES ($1, $2) + ON CONFLICT (user_id) DO UPDATE SET token = excluded.token + """, + user_id, token, + ) + + async def delete_client(self, user_id: UserID) -> None: + await self.db.execute( + "DELETE FROM client WHERE user_id = $1", user_id, + ) + + async def get_avatars(self) -> list[Avatar]: + rows = await self.db.fetch("SELECT url, mxc FROM avatar") + return [Avatar.from_row(row) for row in rows] + + async def put_avatar(self, url: str, mxc: ContentURI) -> None: + await self.db.execute( + """ + INSERT INTO avatar (url, mxc) VALUES ($1, $2) + ON CONFLICT (url) DO NOTHING + """, + url, mxc, + ) + + async def get_webhook_by_id(self, id: uuid.UUID) -> WebhookInfo | None: + row = await self.db.fetchrow( + "SELECT id, repo, user_id, room_id, github_id, secret FROM webhook WHERE id = $1", id, + ) + return WebhookInfo.from_row(row) + + async def get_webhook_by_repo(self, room_id: RoomID, repo: str) -> WebhookInfo | None: + row = await self.db.fetchrow( + "SELECT id, repo, user_id, room_id, github_id, secret FROM webhook WHERE room_id = $1 AND repo = $2", room_id, repo, + ) + return WebhookInfo.from_row(row) + + async def get_webhooks_in_room(self, room_id: RoomID) -> list[WebhookInfo]: + rows = await self.db.fetch( + "SELECT id, repo, user_id, room_id, github_id, secret FROM webhook WHERE room_id = $1", room_id, + ) + return [WebhookInfo.from_row(row) for row in rows] + + async def delete_webhook(self, id: uuid.UUID) -> None: + await self.db.execute( + "DELETE FROM webhook WHERE id = $1", id, + ) + + async def insert_webhook(self, webhook: WebhookInfo, *, _conn: Connection | None = None) -> None: + await (_conn or self.db).execute( + """ + INSERT INTO webhook (id, repo, user_id, room_id, secret, github_id) + VALUES ($1, $2, $3, $4, $5, $6) + """, + *webhook, + ) + + async def set_webhook_github_id(self, id: uuid.UUID, github_id: int) -> None: + await self.db.execute( + "UPDATE webhook SET github_id = $1 WHERE id = $2", github_id, id, + ) + + async def transfer_webhook_repo(self, id: uuid.UUID, new_repo: str) -> None: + await self.db.execute( + "UPDATE webhook SET repo = $1 WHERE id = $2", new_repo, id, + ) + + async def transfer_webhook_rooms(self, old_room: RoomID, new_room: RoomID) -> None: + await self.db.execute( + "UPDATE webhook SET room_id = $1 WHERE room_id = $2 ON CONFLICT (repo, room_id) DO NOTHING", new_room, old_room, + ) + + async def run_post_migration(self, conn: Connection, secret_key: str) -> None: + rows = await self.db.fetch("SELECT id, repo, user_id, room_id, github_id FROM webhook_old") + for row in rows: + id = uuid.UUID(row["id"]) + secret = hmac.new(key=secret_key.encode("utf-8"), digestmod=hashlib.sha256) + secret.update(id.bytes) + secret.update(row["user_id"].encode("utf-8")) + secret.update(row["room_id"].encode("utf-8")) + new_webhook = WebhookInfo( + id=id, + repo=row["repo"], + user_id=row["user_id"], + room_id=row["room_id"], + github_id=row["github_id"], + secret=secret.hexdigest(), + ) + await self.insert_webhook(new_webhook, _conn=conn) + await conn.execute("DROP TABLE needs_post_migration") diff --git a/github/migrations.py b/github/migrations.py new file mode 100644 index 0000000..9c7c1f0 --- /dev/null +++ b/github/migrations.py @@ -0,0 +1,70 @@ +# github - A maubot plugin to act as a GitHub client and webhook receiver. +# Copyright (C) 2025 Tulir Asokan +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Affero General Public License for more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with this program. If not, see . +from mautrix.util.async_db import Connection, Scheme, UpgradeTable + +upgrade_table = UpgradeTable() + +@upgrade_table.register(description="Latest revision", upgrades_to=1) +async def upgrade_latest(conn: Connection, scheme: Scheme) -> None: + needs_migration = False + if await conn.table_exists("webhook"): + needs_migration = True + await conn.execute(""" + ALTER TABLE webhook RENAME TO webhook_old; + ALTER TABLE client RENAME TO client_old; + ALTER TABLE matrix_message RENAME TO matrix_message_old; + """) + await conn.execute( + f"""CREATE TABLE client ( + user_id TEXT NOT NULL, + token TEXT NOT NULL, + PRIMARY KEY (user_id) + )""" + ) + await conn.execute( + """CREATE TABLE webhook ( + id uuid NOT NULL, + repo TEXT NOT NULL, + user_id TEXT NOT NULL, + room_id TEXT NOT NULL, + secret TEXT NOT NULL, + github_id INTEGER, + PRIMARY KEY (id), + CONSTRAINT webhook_repo_room_unique UNIQUE (repo, room_id) + )""" + ) + await conn.execute( + """CREATE TABLE matrix_message ( + message_id TEXT NOT NULL, + room_id TEXT NOT NULL, + event_id TEXT NOT NULL, + PRIMARY KEY (message_id, room_id) + )""" + ) + await conn.execute( + """CREATE TABLE IF NOT EXISTS avatar ( + url TEXT NOT NULL, + mxc TEXT NOT NULL, + PRIMARY KEY (url) + )""" + ) + if needs_migration: + await migrate_legacy_to_v1(conn) + +async def migrate_legacy_to_v1(conn: Connection) -> None: + await conn.execute("INSERT INTO client (user_id, token) SELECT user_id, token FROM client_old") + await conn.execute("INSERT INTO matrix_message (message_id, room_id, event_id) SELECT message_id, room_id, event_id FROM matrix_message_old") + await conn.execute("CREATE TABLE needs_post_migration(noop INTEGER PRIMARY KEY)") diff --git a/github/util/__init__.py b/github/util/__init__.py index ecbf996..b02c41d 100644 --- a/github/util/__init__.py +++ b/github/util/__init__.py @@ -1,3 +1,2 @@ -from .sql_uuid import UUIDType from .recursive_get import recursive_get from .contrast import contrast, hex_to_rgb, rgb_to_hex diff --git a/github/util/sql_uuid.py b/github/util/sql_uuid.py deleted file mode 100644 index c80dcc6..0000000 --- a/github/util/sql_uuid.py +++ /dev/null @@ -1,45 +0,0 @@ -# Based on https://docs.sqlalchemy.org/en/13/core/custom_types.html#backend-agnostic-guid-type -from typing import Union, Optional, Type, Any - -from sqlalchemy import types -from sqlalchemy.dialects import postgresql -import uuid - -InputUUID = Union[uuid.UUID, str, None] - - -class UUIDType(types.TypeDecorator): - """Platform-independent UUID type. - - Uses PostgreSQL's UUID type, otherwise uses - CHAR(32), storing as stringified hex values. - """ - - impl = types.CHAR - - @property - def python_type(self) -> Type[uuid.UUID]: - return uuid.UUID - - def process_literal_param(self, value: InputUUID, dialect: Any) -> Optional[str]: - raise NotImplementedError() - - def load_dialect_impl(self, dialect: Any) -> Any: - if dialect.name == "postgresql": - return dialect.type_descriptor(postgresql.UUID()) - else: - return dialect.type_descriptor(types.CHAR(32)) - - def process_bind_param(self, value: InputUUID, dialect: Any) -> Optional[str]: - if value is None: - return None - elif dialect.name == "postgresql": - return str(value) - elif not isinstance(value, uuid.UUID): - return uuid.UUID(value).hex - return value.hex - - def process_result_value(self, value: InputUUID, dialect: Any) -> Optional[uuid.UUID]: - if value is not None and not isinstance(value, uuid.UUID): - value = uuid.UUID(value) - return value diff --git a/github/webhook/__init__.py b/github/webhook/__init__.py index 81d4319..0ac3076 100644 --- a/github/webhook/__init__.py +++ b/github/webhook/__init__.py @@ -1,3 +1,3 @@ -from .manager import WebhookInfo, WebhookManager from .handler import WebhookHandler from .aggregation import PendingAggregation +from .manager import WebhookInfo, WebhookManager diff --git a/github/webhook/aggregation.py b/github/webhook/aggregation.py index 59d3684..e737c3c 100644 --- a/github/webhook/aggregation.py +++ b/github/webhook/aggregation.py @@ -131,7 +131,7 @@ async def _send(self) -> None: aggregation=self.aggregation, ) if self.event_type == EventType.PUSH and event_id: - self.handler.bot.db.put_event(self.event.message_id, self.webhook_info.room_id, event_id) + await self.handler.bot.db.put_event(self.event.message_id, self.webhook_info.room_id, event_id) def aggregate(self, evt_type: EventType, evt: Event, delivery_id: str) -> bool: postpone = True diff --git a/github/webhook/handler.py b/github/webhook/handler.py index 2e79ae3..a4755ec 100644 --- a/github/webhook/handler.py +++ b/github/webhook/handler.py @@ -82,23 +82,22 @@ def reload_config(self) -> None: self.msgtype = MessageType(self.bot.config["message_options.msgtype"]) or MessageType.NOTICE PendingAggregation.timeout = int(self.bot.config["message_options.aggregation_timeout"]) - async def __call__(self, evt_type: EventType, evt: Event, delivery_id: str, info: WebhookInfo - ) -> None: + async def __call__(self, evt_type: EventType, evt: Event, delivery_id: str, info: WebhookInfo) -> None: if evt_type == EventType.PING: self.log.debug(f"Received ping for {info}: {evt.zen}") - self.bot.webhook_manager.set_github_id(info, evt.hook_id) + await self.bot.webhook_manager.set_github_id(info, evt.hook_id) elif evt_type == EventType.META and evt.action == MetaAction.DELETED: self.log.debug(f"Received delete hook for {info}") - self.bot.webhook_manager.delete(info.id) + await self.bot.webhook_manager.delete(info.id) elif evt_type == EventType.REPOSITORY: if evt.action in (RepositoryAction.TRANSFERRED, RepositoryAction.RENAMED): action = "transfer" if evt.action == RepositoryAction.TRANSFERRED else "rename" name = evt.repository.full_name self.log.debug(f"Received {action} hook {info} -> {name}") - self.bot.webhook_manager.transfer(info, name) + await self.bot.webhook_manager.transfer_repo(info, name) elif evt.action == RepositoryAction.DELETED: self.log.debug(f"Received repo delete hook for {info}") - self.bot.webhook_manager.delete(info.id) + await self.bot.webhook_manager.delete(info.id) elif evt_type == EventType.PUSH and (evt.size is None or evt.distinct_size is None): assert isinstance(evt, PushEvent) evt.size = len(evt.commits) @@ -107,7 +106,7 @@ async def __call__(self, evt_type: EventType, evt: Event, delivery_id: str, info if evt.workflow_job.name == "lock-stale": return assert isinstance(evt, WorkflowJobEvent) - push_evt = self.bot.db.get_event(evt.push_id, info.room_id) + push_evt = await self.bot.db.get_event(evt.push_id, info.room_id) if not push_evt: self.bot.log.debug(f"No message found to react to push {evt.push_id}") return @@ -123,21 +122,19 @@ async def __call__(self, evt_type: EventType, evt: Event, delivery_id: str, info **evt.meta, } - prev_reaction = self.bot.db.get_event(evt.reaction_id, info.room_id) + prev_reaction = await self.bot.db.get_event(evt.reaction_id, info.room_id) if prev_reaction: await self.bot.client.redact(info.room_id, prev_reaction) event_id = await self.bot.client.send_message_event( info.room_id, MautrixEventType.REACTION, reaction ) - self.bot.db.put_event( - evt.reaction_id, info.room_id, event_id, merge=prev_reaction is not None - ) + await self.bot.db.put_event(evt.reaction_id, info.room_id, event_id) if PendingAggregation.timeout < 0: # Aggregations are disabled event_id = await self.send_message(evt_type, evt, info.room_id, {delivery_id}) if evt_type == EventType.PUSH and event_id: - self.bot.db.put_event(evt.message_id, info.room_id, event_id) + await self.bot.db.put_event(evt.message_id, info.room_id, event_id) return for pending in self.pending_aggregations[info.id]: diff --git a/github/webhook/manager.py b/github/webhook/manager.py index dfb6d93..0aa24d1 100644 --- a/github/webhook/manager.py +++ b/github/webhook/manager.py @@ -13,144 +13,68 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import Dict, Union, Optional, Any, Generator from uuid import UUID, uuid4 -import hashlib -import hmac - -from sqlalchemy import MetaData, Table, Column, String, Integer, UniqueConstraint, and_ -from sqlalchemy.engine.base import Engine +import random from mautrix.types import UserID, RoomID -from ..util import UUIDType - - -class WebhookInfo: - __slots__ = ("id", "repo", "user_id", "room_id", "github_id", "_secret_key", "__initialized") - id: UUID - repo: str - user_id: UserID - room_id: RoomID - github_id: Optional[int] - - def __init__(self, id: UUID, repo: str, user_id: UserID, room_id: RoomID, - github_id: Optional[int] = None, _secret_key: bytes = None) -> None: - self.id = id - self.repo = repo - self.user_id = user_id - self.room_id = room_id - self.github_id = github_id - self._secret_key = _secret_key - self.__initialized = True - - def __repr__(self) -> str: - return (f"WebhookInfo(id={self.id!r}, repo={self.repo!r}, user_id={self.user_id!r}," - f" room_id={self.room_id!r}, github_id={self.github_id!r})") - - def __str__(self) -> str: - return (f"webhook {self.id!s} (GH{self.github_id}) from {self.repo} to {self.room_id}" - f" added by {self.user_id}") - - def __delattr__(self, item) -> None: - raise ValueError("Can't change attributes after initialization") - - def __setattr__(self, key: str, value: Any) -> None: - if hasattr(self, "__initialized"): - raise ValueError("Can't change attributes after initialization") - super().__setattr__(key, value) - - @property - def secret(self) -> str: - secret = hmac.new(key=self._secret_key, digestmod=hashlib.sha256) - secret.update(self.id.bytes) - secret.update(self.user_id.encode("utf-8")) - secret.update(self.room_id.encode("utf-8")) - return secret.hexdigest() +from ..db import DBManager, WebhookInfo class WebhookManager: - _table: Table - _db: Engine - _secret: bytes - _webhooks: Dict[UUID, WebhookInfo] + _db: DBManager + _webhooks: dict[UUID, WebhookInfo] - def __init__(self, secret: str, db: Engine, metadata: MetaData): - self._secret = secret.encode("utf-8") + def __init__(self, db: DBManager): self._db = db - self._table = Table("webhook", metadata, - Column("id", UUIDType, primary_key=True), - Column("repo", String(255), nullable=False), - Column("user_id", String(255), nullable=False), - Column("room_id", String(255), nullable=False), - Column("github_id", Integer, nullable=True), - UniqueConstraint("repo", "room_id")) self._webhooks = {} - def create(self, repo: str, user_id: UserID, room_id: RoomID) -> WebhookInfo: + async def create(self, repo: str, user_id: UserID, room_id: RoomID) -> WebhookInfo: info = WebhookInfo(id=uuid4(), repo=repo, user_id=user_id, room_id=room_id, - _secret_key=self._secret) - self._db.execute(self._table.insert().values( - id=info.id, github_id=info.github_id, repo=repo, - user_id=info.user_id, room_id=info.room_id)) + secret=random.randbytes(16).hex()) + await self._db.insert_webhook(info) self._webhooks[info.id] = info return info - def set_github_id(self, info: WebhookInfo, github_id: int) -> WebhookInfo: - self._db.execute(self._table.update() - .where(self._table.c.id == info.id) - .values(github_id=github_id)) - return self._select(info.id) + async def set_github_id(self, info: WebhookInfo, github_id: int) -> WebhookInfo: + await self._db.set_webhook_github_id(info.id, github_id) + return await self.get_by_id(info.id) - def transfer(self, info: WebhookInfo, new_name: str) -> WebhookInfo: - self._db.execute(self._table.update() - .where(self._table.c.id == info.id) - .values(repo=new_name)) - return self._select(info.id) + async def transfer_repo(self, info: WebhookInfo, new_name: str) -> WebhookInfo: + await self._db.transfer_webhook_repo(info.id, new_name) + return await self.get_by_id(info.id) - def delete(self, id: UUID) -> None: - self._db.execute(self._table.delete().where(self._table.c.id == id)) - try: - del self._webhooks[id] - except KeyError: - pass + async def transfer_rooms(self, old_room: RoomID, new_room: RoomID) -> list[WebhookInfo]: + await self._db.transfer_webhook_rooms(old_room, new_room) + return await self.get_all_for_room(new_room) - def _execute_select(self, *where_clause) -> Optional[WebhookInfo]: - rows = self._db.execute(self._table.select().where(where_clause[0] if len(where_clause) == 1 - else and_(*where_clause))) - try: - info = WebhookInfo(*next(rows), _secret_key=self._secret) - self._webhooks[info.id] = info - return info - except StopIteration: - return None + async def delete(self, id: UUID) -> None: + self._webhooks.pop(id, None) + await self._db.delete_webhook(id) - def _select(self, id: UUID) -> Optional[WebhookInfo]: - return self._execute_select(self._table.c.id == id) + def _add_to_cache(self, info: WebhookInfo | None) -> WebhookInfo | None: + if info is None: + return None + self._webhooks[info.id] = info + return info - def get(self, id: UUID) -> Optional[WebhookInfo]: + async def get(self, id: UUID) -> WebhookInfo | None: try: return self._webhooks[id] except KeyError: - return self._select(id) - - def get_all_for_room(self, room_id: RoomID) -> Generator[WebhookInfo, None, None]: - rows = self._db.execute(self._table.select().where(self._table.c.room_id == room_id)) - return (WebhookInfo(*row, _secret_key=self._secret) for row in rows) + return await self.get_by_id(id) - def find(self, repo: str, room_id: RoomID) -> Optional[WebhookInfo]: - return self._execute_select(self._table.c.repo == repo, self._table.c.room_id == room_id) + async def get_by_id(self, id: UUID) -> WebhookInfo | None: + return self._add_to_cache(await self._db.get_webhook_by_id(id)) - def __delitem__(self, key: UUID) -> None: - self.delete(key) + async def get_by_repo(self, room_id: RoomID, repo: str) -> WebhookInfo | None: + return self._add_to_cache(await self._db.get_webhook_by_repo(room_id, repo)) - def __getitem__(self, item: Union[str, UUID]) -> WebhookInfo: - if not isinstance(item, UUID): - item = UUID(item) - value = self.get(item) - if not value: - raise KeyError(item) - return value + async def get_all_for_room(self, room_id: RoomID) -> list[WebhookInfo]: + items = await self._db.get_webhooks_in_room(room_id) + for item in items: + self._webhooks[item.id] = item + return items diff --git a/maubot.yaml b/maubot.yaml index 81e81df..9e8ae39 100644 --- a/maubot.yaml +++ b/maubot.yaml @@ -8,5 +8,6 @@ main_class: GitHubBot extra_files: - base-config.yaml database: true +database_type: asyncpg webapp: true config: true From 196093a127cb21d466074d97f94e11c8bfbafcf5 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 7 May 2025 16:33:54 +0300 Subject: [PATCH 40/43] Blacken and isort code --- .github/workflows/python-lint.yml | 25 +++++ .pre-commit-config.yaml | 20 ++++ github/api/client.py | 148 +++++++++++++++++++---------- github/api/types.py | 14 +-- github/api/webhook.py | 25 ++--- github/avatar_manager.py | 10 +- github/bot.py | 26 +++--- github/client_manager.py | 29 +++--- github/commands.py | 149 ++++++++++++++++++------------ github/config.py | 10 +- github/db.py | 67 +++++++++----- github/migrations.py | 14 ++- github/template/loader.py | 9 +- github/template/manager.py | 7 +- github/template/proxy.py | 2 +- github/template/util.py | 19 ++-- github/util/__init__.py | 2 +- github/util/contrast.py | 4 +- github/webhook/__init__.py | 2 +- github/webhook/aggregation.py | 70 +++++++++----- github/webhook/handler.py | 26 +++--- github/webhook/manager.py | 14 +-- pyproject.toml | 11 +++ 23 files changed, 465 insertions(+), 238 deletions(-) create mode 100644 .github/workflows/python-lint.yml create mode 100644 .pre-commit-config.yaml create mode 100644 pyproject.toml diff --git a/.github/workflows/python-lint.yml b/.github/workflows/python-lint.yml new file mode 100644 index 0000000..18be560 --- /dev/null +++ b/.github/workflows/python-lint.yml @@ -0,0 +1,25 @@ +name: Python lint + +on: [push, pull_request] + +jobs: + lint: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - uses: actions/setup-python@v3 + with: + python-version: "3.13" + - uses: isort/isort-action@master + with: + sortPaths: "./rss" + - uses: psf/black@stable + with: + src: "./rss" + - name: pre-commit + run: | + pip install pre-commit + pre-commit run -av trailing-whitespace + pre-commit run -av end-of-file-fixer + pre-commit run -av check-yaml + pre-commit run -av check-added-large-files diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..7e87c21 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,20 @@ +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v5.0.0 + hooks: + - id: trailing-whitespace + exclude_types: [markdown] + - id: end-of-file-fixer + - id: check-yaml + - id: check-added-large-files + - repo: https://github.com/psf/black + rev: 25.1.0 + hooks: + - id: black + language_version: python3 + files: ^github/.*\.pyi?$ + - repo: https://github.com/PyCQA/isort + rev: 6.0.0 + hooks: + - id: isort + files: ^github/.*\.pyi?$ diff --git a/github/api/client.py b/github/api/client.py index 115a648..2b91952 100644 --- a/github/api/client.py +++ b/github/api/client.py @@ -13,10 +13,10 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import Optional, Dict, Union, Any, Awaitable, List +from typing import Any, Awaitable, Dict, List, Optional, Union +import json import random import string -import json from aiohttp import ClientSession from yarl import URL @@ -57,7 +57,9 @@ class GitHubClient: token: str _login_state: str - def __init__(self, http: ClientSession, client_id: str, client_secret: str, token: str) -> None: + def __init__( + self, http: ClientSession, client_id: str, client_secret: str, token: str + ) -> None: self.http = http self.client_id = client_id self.client_secret = client_secret @@ -66,37 +68,59 @@ def __init__(self, http: ClientSession, client_id: str, client_secret: str, toke def get_login_url(self, redirect_uri: Union[str, URL], scope: str = "user repo") -> URL: self._login_state = "".join(random.choices(string.ascii_lowercase + string.digits, k=64)) - return self.login_url.with_query({ - "client_id": self.client_id, - "redirect_uri": str(redirect_uri), - "scope": scope, - "state": self._login_state, - }) + return self.login_url.with_query( + { + "client_id": self.client_id, + "redirect_uri": str(redirect_uri), + "scope": scope, + "state": self._login_state, + } + ) async def finish_login(self, code: str, state: str) -> None: if state != self._login_state: raise ValueError("Invalid state") - resp = await self.http.post(self.login_finish_url, json={ - "client_id": self.client_id, - "client_secret": self.client_secret, - "code": code, - "state": self._login_state, - }, headers={ - "Accept": "application/json", - }) + resp = await self.http.post( + self.login_finish_url, + json={ + "client_id": self.client_id, + "client_secret": self.client_secret, + "code": code, + "state": self._login_state, + }, + headers={ + "Accept": "application/json", + }, + ) data = await resp.json() self.token = data["access_token"] - def query(self, query: str, args: str = "", variables: Optional[Dict] = None, - path: Optional[str] = None) -> Awaitable[Any]: + def query( + self, + query: str, + args: str = "", + variables: Optional[Dict] = None, + path: Optional[str] = None, + ) -> Awaitable[Any]: return self.call("query", query, args, variables, path) - def mutate(self, query: str, args: str = "", variables: Optional[Dict] = None, - path: Optional[str] = None) -> Awaitable[Any]: + def mutate( + self, + query: str, + args: str = "", + variables: Optional[Dict] = None, + path: Optional[str] = None, + ) -> Awaitable[Any]: return self.call("mutation", query, args, variables, path) - async def call(self, query_type: str, query: str, args: str, variables: Optional[Dict] = None, - path: Optional[str] = None) -> Any: + async def call( + self, + query_type: str, + query: str, + args: str, + variables: Optional[Dict] = None, + path: Optional[str] = None, + ) -> Any: full_query = query_type if args: full_query += f" ({args})" @@ -109,8 +133,9 @@ async def call(self, query_type: str, query: str, args: str, variables: Optional try: data = resp["data"] except KeyError: - raise GraphQLError(type="UNKNOWN_ERROR", - message="Unknown error: GitHub didn't return any data") + raise GraphQLError( + type="UNKNOWN_ERROR", message="Unknown error: GitHub didn't return any data" + ) if path: return recursive_get(data, path) return data @@ -130,18 +155,18 @@ def rest_v3_headers(self) -> Dict[str, str]: } async def call_raw(self, query: str, variables: Optional[Dict] = None) -> dict: - resp = await self.http.post(self.api_url, - json={ - "query": query, - "variables": variables or {} - }, - headers=self.headers) + resp = await self.http.post( + self.api_url, json={"query": query, "variables": variables or {}}, headers=self.headers + ) return await resp.json() @property def _token_url(self) -> URL: - return ((self.base_url / "applications" / self.client_id / "token") - .with_user(self.client_id).with_password(self.client_secret)) + return ( + (self.base_url / "applications" / self.client_id / "token") + .with_user(self.client_id) + .with_password(self.client_secret) + ) async def reset_token(self) -> Optional[str]: resp = await self.http.patch(self._token_url, json={"access_token": self.token}) @@ -155,21 +180,33 @@ async def delete_token(self) -> None: await self.http.delete(self._token_url, json={"access_token": self.token}) async def list_webhooks(self, owner: str, repo: str) -> List[Webhook]: - resp = await self.http.get(self.base_url / "repos" / owner / repo / "hooks", - headers=self.rest_v3_headers) + resp = await self.http.get( + self.base_url / "repos" / owner / repo / "hooks", headers=self.rest_v3_headers + ) return [Webhook.deserialize(info) for info in await resp.json()] async def get_webhook(self, owner: str, repo: str, hook_id: int) -> Webhook: - resp = await self.http.get(self.base_url / "repos" / owner / repo / "hooks" / str(hook_id), - headers=self.rest_v3_headers) + resp = await self.http.get( + self.base_url / "repos" / owner / repo / "hooks" / str(hook_id), + headers=self.rest_v3_headers, + ) data = await resp.json() if resp.status != 200: raise GitHubError(status_code=resp.status, **data) return Webhook.deserialize(data) - async def create_webhook(self, owner: str, repo: str, url: URL, *, active: bool = True, - events: OptStrList = None, content_type: str = "form", - secret: Optional[str] = None, insecure_ssl: bool = False) -> Webhook: + async def create_webhook( + self, + owner: str, + repo: str, + url: URL, + *, + active: bool = True, + events: OptStrList = None, + content_type: str = "form", + secret: Optional[str] = None, + insecure_ssl: bool = False, + ) -> Webhook: payload = { "name": "web", "config": { @@ -181,18 +218,31 @@ async def create_webhook(self, owner: str, repo: str, url: URL, *, active: bool "events": events or ["push"], "active": active, } - resp = await self.http.post(self.base_url / "repos" / owner / repo / "hooks", - data=json.dumps(payload), headers=self.rest_v3_headers) + resp = await self.http.post( + self.base_url / "repos" / owner / repo / "hooks", + data=json.dumps(payload), + headers=self.rest_v3_headers, + ) data = await resp.json() if resp.status != 201: raise GitHubError(status_code=resp.status, **data) return Webhook.deserialize(data) - async def edit_webhook(self, owner: str, repo: str, hook_id: int, *, url: Optional[URL] = None, - active: Optional[bool] = None, events: OptStrList = None, - add_events: OptStrList = None, remove_events: OptStrList = None, - content_type: Optional[str] = None, secret: Optional[str] = None, - insecure_ssl: Optional[bool] = None) -> Webhook: + async def edit_webhook( + self, + owner: str, + repo: str, + hook_id: int, + *, + url: Optional[URL] = None, + active: Optional[bool] = None, + events: OptStrList = None, + add_events: OptStrList = None, + remove_events: OptStrList = None, + content_type: Optional[str] = None, + secret: Optional[str] = None, + insecure_ssl: Optional[bool] = None, + ) -> Webhook: payload = {} if events: if add_events or remove_events: @@ -216,7 +266,9 @@ async def edit_webhook(self, owner: str, repo: str, hook_id: int, *, url: Option payload["config"] = config resp = await self.http.patch( self.base_url / "repos" / owner / repo / "hooks" / str(hook_id), - data=json.dumps(payload), headers=self.rest_v3_headers) + data=json.dumps(payload), + headers=self.rest_v3_headers, + ) data = await resp.json() if resp.status != 200: raise GitHubError(status_code=resp.status, **data) diff --git a/github/api/types.py b/github/api/types.py index 60bf41a..1d169e7 100644 --- a/github/api/types.py +++ b/github/api/types.py @@ -13,13 +13,13 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import Optional, NewType, List, Union, Type, Dict, Any +from typing import Any, Dict, List, NewType, Optional, Type, Union from datetime import datetime from attr import dataclass import attr -from mautrix.types import SerializableAttrs, SerializableEnum, serializer, deserializer, JSON +from mautrix.types import JSON, SerializableAttrs, SerializableEnum, deserializer, serializer HubDateTime = NewType("HubDateTime", datetime) ISO_FORMAT = "%Y-%m-%dT%H:%M:%S%z" @@ -471,7 +471,7 @@ class IssueComment(SerializableAttrs): def meta(self) -> Dict[str, Any]: return { - "id": self.id , + "id": self.id, "node_id": self.node_id, } @@ -562,7 +562,7 @@ class DeleteEvent(SerializableAttrs): class MetaAction(SerializableEnum): - DELETED = 'deleted' + DELETED = "deleted" @dataclass @@ -594,7 +594,7 @@ class CommitComment(SerializableAttrs): def meta(self) -> Dict[str, Any]: return { - "id": self.id , + "id": self.id, "node_id": self.node_id, "commit_id": self.commit_id, } @@ -769,7 +769,7 @@ class PartialPullRequest(SerializableAttrs): def meta(self) -> Dict[str, Any]: return { - "id": self.id , + "id": self.id, "node_id": self.node_id, "number": self.number, } @@ -918,7 +918,7 @@ class ReviewComment(SerializableAttrs): def meta(self) -> Dict[str, Any]: return { - "id": self.id , + "id": self.id, "node_id": self.node_id, "pull_request_review_id": self.pull_request_review_id, "commit_id": self.commit_id, diff --git a/github/api/webhook.py b/github/api/webhook.py index 948d0ce..63f1864 100644 --- a/github/api/webhook.py +++ b/github/api/webhook.py @@ -13,7 +13,7 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import Optional, Protocol, TYPE_CHECKING +from typing import TYPE_CHECKING, Optional, Protocol from uuid import UUID import hashlib import hmac @@ -22,23 +22,23 @@ from aiohttp import web from attr import dataclass -from mautrix.types import SerializerError, RoomID from maubot.handlers import web as web_handler +from mautrix.types import RoomID, SerializerError -from .types import EventType, Event, EVENT_CLASSES from ..db import WebhookInfo +from .types import EVENT_CLASSES, Event, EventType if TYPE_CHECKING: from ..webhook import WebhookManager class HandlerFunc(Protocol): - async def __call__(self, evt_type: EventType, evt: Event, delivery_id: str, - info: WebhookInfo) -> None: + async def __call__( + self, evt_type: EventType, evt: Event, delivery_id: str, info: WebhookInfo + ) -> None: pass - @dataclass(frozen=True) class GlobalWebhookInfo(WebhookInfo): room_id: RoomID @@ -54,12 +54,13 @@ def __str__(self) -> str: class GitHubWebhookReceiver: - handler: 'HandlerFunc' + handler: "HandlerFunc" secrets: "WebhookManager" global_secret: Optional[str] - def __init__(self, handler: 'HandlerFunc', secrets: "WebhookManager", - global_secret: Optional[str]) -> None: + def __init__( + self, handler: "HandlerFunc", secrets: "WebhookManager", global_secret: Optional[str] + ) -> None: self.handler = handler self.secrets = secrets self.global_secret = global_secret @@ -72,7 +73,9 @@ async def handle_global(self, request: web.Request) -> web.Response: room_id = RoomID(request.query["room"]) except KeyError: return web.Response(status=400, text="room query param missing") - return await self._handle(request, GlobalWebhookInfo(room_id=room_id, secret=self.global_secret)) + return await self._handle( + request, GlobalWebhookInfo(room_id=room_id, secret=self.global_secret) + ) @web_handler.post("/webhook/{id}") async def handle(self, request: web.Request) -> web.Response: @@ -85,7 +88,7 @@ async def handle(self, request: web.Request) -> web.Response: return web.Response(status=404, text="Webhook not found") return await self._handle(request, webhook_info) - async def _handle(self, request: web.Request, webhook_info: 'WebhookInfo') -> web.Response: + async def _handle(self, request: web.Request, webhook_info: "WebhookInfo") -> web.Response: try: signature = request.headers["X-Hub-Signature"] event_type = EventType(request.headers["X-Github-Event"]) diff --git a/github/avatar_manager.py b/github/avatar_manager.py index fd35985..1c28949 100644 --- a/github/avatar_manager.py +++ b/github/avatar_manager.py @@ -1,7 +1,7 @@ from typing import TYPE_CHECKING import asyncio -from sqlalchemy import MetaData, Table, Column, Text +from sqlalchemy import Column, MetaData, Table, Text from sqlalchemy.engine.base import Engine from mautrix.types import ContentURI @@ -13,21 +13,19 @@ class AvatarManager: - bot: 'GitHubBot' + bot: "GitHubBot" _avatars: dict[str, ContentURI] _db: DBManager _lock: asyncio.Lock - def __init__(self, bot: 'GitHubBot') -> None: + def __init__(self, bot: "GitHubBot") -> None: self.bot = bot self._db = bot.db self._lock = asyncio.Lock() self._avatars = {} async def load_db(self) -> None: - self._avatars = {url: ContentURI(mxc) - for url, mxc - in await self._db.get_avatars()} + self._avatars = {url: ContentURI(mxc) for url, mxc in await self._db.get_avatars()} async def get_mxc(self, url: str) -> ContentURI: try: diff --git a/github/bot.py b/github/bot.py index 2606034..d443469 100644 --- a/github/bot.py +++ b/github/bot.py @@ -15,19 +15,18 @@ # along with this program. If not, see . from typing import Type -from mautrix.util.async_db import UpgradeTable - from maubot import Plugin from mautrix.util import background_task +from mautrix.util.async_db import UpgradeTable -from .db import DBManager -from .migrations import upgrade_table -from .webhook import WebhookManager, WebhookHandler -from .client_manager import ClientManager from .api import GitHubWebhookReceiver +from .avatar_manager import AvatarManager +from .client_manager import ClientManager from .commands import Commands from .config import Config -from .avatar_manager import AvatarManager +from .db import DBManager +from .migrations import upgrade_table +from .webhook import WebhookHandler, WebhookManager class GitHubBot(Plugin): @@ -48,14 +47,17 @@ async def start(self) -> None: self.log.info("Running database post-migration") async with self.database.acquire() as conn, conn.transaction(): await self.db.run_post_migration(conn, self.config["webhook_key"]) - self.clients = ClientManager(self.config["client_id"], self.config["client_secret"], - self.http, self.db) + self.clients = ClientManager( + self.config["client_id"], self.config["client_secret"], self.http, self.db + ) self.webhook_manager = WebhookManager(self.db) self.webhook_handler = WebhookHandler(bot=self) self.avatars = AvatarManager(bot=self) - self.webhook_receiver = GitHubWebhookReceiver(handler=self.webhook_handler, - secrets=self.webhook_manager, - global_secret=self.config["global_webhook_secret"]) + self.webhook_receiver = GitHubWebhookReceiver( + handler=self.webhook_handler, + secrets=self.webhook_manager, + global_secret=self.config["global_webhook_secret"], + ) self.commands = Commands(bot=self) await self.clients.load_db() diff --git a/github/client_manager.py b/github/client_manager.py index 0af4835..f0efc20 100644 --- a/github/client_manager.py +++ b/github/client_manager.py @@ -13,10 +13,10 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from aiohttp import web, ClientError, ClientSession +from aiohttp import ClientError, ClientSession, web -from mautrix.types import UserID from maubot.handlers import web as web_handler +from mautrix.types import UserID from .api import GitHubClient from .db import DBManager @@ -37,15 +37,17 @@ def __init__(self, client_id: str, client_secret: str, http: ClientSession, db: self._clients = {} async def load_db(self) -> None: - self._clients = {user_id: self._make(token) - for user_id, token - in await self._db.get_clients()} + self._clients = { + user_id: self._make(token) for user_id, token in await self._db.get_clients() + } def _make(self, token: str) -> GitHubClient: - return GitHubClient(http=self._http, - client_id=self.client_id, - client_secret=self.client_secret, - token=token) + return GitHubClient( + http=self._http, + client_id=self.client_id, + client_secret=self.client_secret, + token=token, + ) async def put(self, user_id: UserID, token: str) -> None: await self._db.put_client(user_id, token) @@ -77,9 +79,12 @@ async def login_callback(self, request: web.Request) -> web.Response: except KeyError: pass else: - return web.Response(status=400, text=f"Failed to log in: {error_code}\n\n" - f"{error_msg}\n\n" - f"More info at {error_uri}") + return web.Response( + status=400, + text=f"Failed to log in: {error_code}\n\n" + f"{error_msg}\n\n" + f"More info at {error_uri}", + ) try: user_id = UserID(request.query["user_id"]) code = request.query["code"] diff --git a/github/commands.py b/github/commands.py index 6442a78..bfb2b9d 100644 --- a/github/commands.py +++ b/github/commands.py @@ -13,12 +13,12 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import Tuple, Optional, Set, Dict, Any, TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Dict, Optional, Set, Tuple import json from maubot import MessageEvent from maubot.handlers import command, event -from mautrix.types import EventType, Event, ReactionEvent, RelationType +from mautrix.types import Event, EventType, ReactionEvent, RelationType from .api import GitHubClient, GitHubError, GraphQLError @@ -28,7 +28,7 @@ def authenticated(_outer_fn=None, *, required: bool = True, error: bool = True): def decorator(fn): - async def wrapper(self: 'Commands', evt: Event, **kwargs) -> None: + async def wrapper(self: "Commands", evt: Event, **kwargs) -> None: client = self.bot.clients.get(evt.sender) if required and (not client or not client.token): if error and hasattr(evt, "reply"): @@ -41,8 +41,10 @@ async def wrapper(self: 'Commands', evt: Event, **kwargs) -> None: except GraphQLError as e: if error: if e.type == "INSUFFICIENT_SCOPES": - await evt.reply("Your login doesn't have sufficient access to do that. " - "Try adding more permissions with `!github login`.") + await evt.reply( + "Your login doesn't have sufficient access to do that. " + "Try adding more permissions with `!github login`." + ) else: await evt.reply(str(e)) @@ -65,7 +67,7 @@ async def get_relation_target(evt: Event, expected_type: RelationType) -> Option def with_webhook_meta(relation_type: RelationType): def decorator(fn): - async def wrapper(self: 'Commands', evt: Event, **kwargs) -> None: + async def wrapper(self: "Commands", evt: Event, **kwargs) -> None: webhook_meta = await get_relation_target(evt, relation_type) if not webhook_meta: return @@ -80,12 +82,12 @@ async def wrapper(self: 'Commands', evt: Event, **kwargs) -> None: class Commands: - bot: 'GitHubBot' + bot: "GitHubBot" _command_prefix: str _aliases: Set[str] - def __init__(self, bot: 'GitHubBot') -> None: + def __init__(self, bot: "GitHubBot") -> None: self.bot = bot self.reload_config() @@ -101,9 +103,11 @@ def reload_config(self) -> None: self._command_prefix = "github" self._aliases = {"github", "gh"} - @command.new(name=lambda self: self._command_prefix, - aliases=lambda self, alias: alias in self._aliases, - require_subcommand=True) + @command.new( + name=lambda self: self._command_prefix, + aliases=lambda self, alias: alias in self._aliases, + require_subcommand=True, + ) async def github(self, evt: MessageEvent) -> None: pass @@ -120,12 +124,17 @@ async def login(self, evt: MessageEvent, flags: str, client: Optional[GitHubClie scopes.remove("admin:repo_hook") if "--private" in flags: scopes.append("repo") - login_url = str(self.bot.clients.get(evt.sender, create=True).get_login_url( - redirect_uri=redirect_url, scope=" ".join(scopes))) + login_url = str( + self.bot.clients.get(evt.sender, create=True).get_login_url( + redirect_uri=redirect_url, scope=" ".join(scopes) + ) + ) if client: username = await client.query("viewer { login }", path="viewer.login") - await evt.reply(f"You're already logged in as @{username}, but you can " - f"[click here to switch to a different account]({login_url})") + await evt.reply( + f"You're already logged in as @{username}, but you can " + f"[click here to switch to a different account]({login_url})" + ) else: await evt.reply(f"[Click here to log in]({login_url})") @@ -139,27 +148,33 @@ async def logout(self, evt: MessageEvent, client: GitHubClient) -> None: @event.on(EventType.ROOM_MESSAGE) @authenticated(error=False) @with_webhook_meta(RelationType.REFERENCE) - async def handle_message(self, evt: MessageEvent, client: GitHubClient, - webhook_meta: Dict[str, Any]) -> None: + async def handle_message( + self, evt: MessageEvent, client: GitHubClient, webhook_meta: Dict[str, Any] + ) -> None: try: full_action = (webhook_meta["event_type"], webhook_meta["action"]) except KeyError: return commentable_actions = (("issues", "opened"), ("issue_comment", "created")) if full_action in commentable_actions: - await client.mutate(query="addComment(input: $input) { clientMutationId }", - args="$input: AddCommentInput!", - variables={"input": { - "subjectId": webhook_meta["issue"]["node_id"], - "body": evt.content.body, - }}) + await client.mutate( + query="addComment(input: $input) { clientMutationId }", + args="$input: AddCommentInput!", + variables={ + "input": { + "subjectId": webhook_meta["issue"]["node_id"], + "body": evt.content.body, + } + }, + ) # We don't need a confirmation here since there must be a webhook. @event.on(EventType.REACTION) @authenticated(error=False) @with_webhook_meta(RelationType.ANNOTATION) - async def handle_reaction(self, evt: ReactionEvent, client: GitHubClient, - webhook_meta: Dict[str, Any]) -> None: + async def handle_reaction( + self, evt: ReactionEvent, client: GitHubClient, webhook_meta: Dict[str, Any] + ) -> None: reaction_map = { "👍": "THUMBS_UP", "👎": "THUMBS_DOWN", @@ -182,9 +197,11 @@ async def handle_reaction(self, evt: ReactionEvent, client: GitHubClient, subject_id = webhook_meta["comment"]["node_id"] else: return - await client.mutate(query="addReaction(input: $input) { clientMutationId }", - args="$input: AddReactionInput!", - variables={"input": {"content": reaction, "subjectId": subject_id}}) + await client.mutate( + query="addReaction(input: $input) { clientMutationId }", + args="$input: AddReactionInput!", + variables={"input": {"content": reaction, "subjectId": subject_id}}, + ) @github.subcommand("ping", help="Check your login status.") @authenticated @@ -205,52 +222,64 @@ async def raw_query(self, evt: MessageEvent, query: str, client: GitHubClient) - await evt.reply(f"Failed to parse variables: {err}") return resp = await client.call_raw(query, variables) - await evt.reply("
    "
    -                        f"{json.dumps(resp, indent=2)}"
    -                        "
    ", allow_html=True) + await evt.reply( + "
    " f"{json.dumps(resp, indent=2)}" "
    ", + allow_html=True, + ) @github.subcommand("create", help="Create an issue. Title on first line, body on other lines") @command.argument("repo", required=False, matches=repo_syntax, label="owner/repo") @command.argument("data", required=True, pass_raw=True, label="title and body") @authenticated - async def create_issue(self, evt: MessageEvent, repo: Tuple[str, str], data: str, - client: GitHubClient) -> None: + async def create_issue( + self, evt: MessageEvent, repo: Tuple[str, str], data: str, client: GitHubClient + ) -> None: title, body = data.split("\n", 1) if "\n" in data else (data, "") if not repo: # TODO support setting default repo await evt.reply("This room does not have a default repo") return - repo_id = await client.query(query="repository (name: $name, owner: $owner) { id }", - args="$owner: String!, $name: String!", - variables={"owner": repo[0], "name": repo[1]}, - path="repository.id") - issue = await client.mutate(query="createIssue(input: $input) { issue { number url } }", - args="$input: CreateIssueInput!", - variables={"input": { - "repositoryId": repo_id, - "title": title, - "body": body, - }}, - path="createIssue.issue") + repo_id = await client.query( + query="repository (name: $name, owner: $owner) { id }", + args="$owner: String!, $name: String!", + variables={"owner": repo[0], "name": repo[1]}, + path="repository.id", + ) + issue = await client.mutate( + query="createIssue(input: $input) { issue { number url } }", + args="$input: CreateIssueInput!", + variables={ + "input": { + "repositoryId": repo_id, + "title": title, + "body": body, + } + }, + path="createIssue.issue", + ) await evt.reply(f"Created [issue #{issue['number']}]({issue['url']})") @github.subcommand("webhook", aliases=["w"], help="Manage webhooks.", required_subcommand=True) - async def webhook(self, evt: MessageEvent, repo: Tuple[str, str], client: GitHubClient) -> None: + async def webhook( + self, evt: MessageEvent, repo: Tuple[str, str], client: GitHubClient + ) -> None: pass @webhook.subcommand("list", aliases=["ls", "l"], help="List webhooks in this room.") async def webhook_list(self, evt: MessageEvent) -> None: hooks = await self.bot.webhook_manager.get_all_for_room(evt.room_id) - info = "\n".join(f"* `{hook.repo}` added by " - f"[{hook.user_id}](https://matrix.to/#/{hook.user_id})" - for hook in hooks) + info = "\n".join( + f"* `{hook.repo}` added by " f"[{hook.user_id}](https://matrix.to/#/{hook.user_id})" + for hook in hooks + ) await evt.reply(f"GitHub webhooks in this room:\n\n{info}") @webhook.subcommand("add", aliases=["a", "create", "c"], help="Add a webhook for this room.") @command.argument("repo", required=True, matches=repo_syntax, label="owner/repo") @authenticated - async def webhook_create(self, evt: MessageEvent, repo: Tuple[str, str], client: GitHubClient - ) -> None: + async def webhook_create( + self, evt: MessageEvent, repo: Tuple[str, str], client: GitHubClient + ) -> None: repo_name = f"{repo[0]}/{repo[1]}" existing = await self.bot.webhook_manager.get_by_repo(repo_name, evt.room_id) if existing: @@ -260,7 +289,8 @@ async def webhook_create(self, evt: MessageEvent, repo: Tuple[str, str], client: webhook = await self.bot.webhook_manager.create(repo_name, evt.sender, evt.room_id) try: await client.create_webhook( - *repo, url=self.bot.webapp_url / "webhook" / str(webhook.id), + *repo, + url=self.bot.webapp_url / "webhook" / str(webhook.id), secret=webhook.secret, content_type="json", events=["*"], @@ -274,8 +304,9 @@ async def webhook_create(self, evt: MessageEvent, repo: Tuple[str, str], client: @webhook.subcommand("remove", aliases=["delete", "rm", "del"]) @command.argument("repo", required=True, matches=repo_syntax, label="owner/repo") @authenticated(required=False) - async def webhook_remove(self, evt: MessageEvent, repo: Tuple[str, str], - client: Optional[GitHubClient]) -> None: + async def webhook_remove( + self, evt: MessageEvent, repo: Tuple[str, str], client: Optional[GitHubClient] + ) -> None: repo_name = f"{repo[0]}/{repo[1]}" webhook_info = await self.bot.webhook_manager.get_by_repo(repo_name, evt.room_id) if not webhook_info: @@ -291,8 +322,9 @@ async def webhook_remove(self, evt: MessageEvent, repo: Tuple[str, str], await evt.reply("Webhook deleted successfully") return else: - self.bot.log.warning(f"Failed to remove {webhook_info} from GitHub", - exc_info=True) + self.bot.log.warning( + f"Failed to remove {webhook_info} from GitHub", exc_info=True + ) else: await evt.reply("Webhook deleted successfully") return @@ -303,6 +335,7 @@ async def webhook_remove(self, evt: MessageEvent, repo: Tuple[str, str], @webhook.subcommand("inspect", aliases=["i"]) @command.argument("repo", required=True, matches=repo_syntax, label="owner/repo") @authenticated(required=False) - async def webhook_inspect(self, evt: MessageEvent, repo: Tuple[str, str], client: GitHubClient - ) -> None: + async def webhook_inspect( + self, evt: MessageEvent, repo: Tuple[str, str], client: GitHubClient + ) -> None: pass diff --git a/github/config.py b/github/config.py index edb3d40..ecf9aa0 100644 --- a/github/config.py +++ b/github/config.py @@ -13,8 +13,8 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -import string import random +import string from mautrix.util.config import BaseProxyConfig, ConfigUpdateHelper @@ -25,9 +25,11 @@ class Config(BaseProxyConfig): def do_update(self, helper: ConfigUpdateHelper) -> None: helper.copy("client_id") helper.copy("client_secret") - helper.base["webhook_key"] = ("".join(random.choices(secret_charset, k=64)) - if helper.source.get("webhook_key", "generate") == "generate" - else helper.source["webhook_key"]) + helper.base["webhook_key"] = ( + "".join(random.choices(secret_charset, k=64)) + if helper.source.get("webhook_key", "generate") == "generate" + else helper.source["webhook_key"] + ) helper.copy("global_webhook_secret") helper.copy("reset_tokens") helper.copy("command_options.prefix") diff --git a/github/db.py b/github/db.py index 2fed17b..2bd1c69 100644 --- a/github/db.py +++ b/github/db.py @@ -14,16 +14,16 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -import uuid -import hmac -import hashlib from typing import Optional +import hashlib +import hmac +import uuid from asyncpg import Record from attr import dataclass -from mautrix.types import UserID, EventID, RoomID, ContentURI -from mautrix.util.async_db import Database, Connection +from mautrix.types import ContentURI, EventID, RoomID, UserID +from mautrix.util.async_db import Connection, Database @dataclass(frozen=True) @@ -59,6 +59,7 @@ def from_row(cls, row: Record | None) -> Optional["Avatar"]: mxc=mxc, ) + @dataclass(frozen=True) class WebhookInfo: id: uuid.UUID @@ -88,8 +89,10 @@ def from_row(cls, row: Record | None) -> Optional["WebhookInfo"]: ) def __str__(self) -> str: - return (f"webhook {self.id!s} (GH{self.github_id}) from {self.repo} to {self.room_id}" - f" added by {self.user_id}") + return ( + f"webhook {self.id!s} (GH{self.github_id}) from {self.repo} to {self.room_id}" + f" added by {self.user_id}" + ) class DBManager: @@ -101,18 +104,24 @@ def __init__(self, db: Database) -> None: async def get_event(self, message_id: str, room_id: RoomID) -> EventID | None: return await self.db.fetchval( "SELECT event_id FROM matrix_message WHERE message_id = $1 AND room_id = $2", - message_id, room_id, + message_id, + room_id, ) async def put_event( - self, message_id: str, room_id: RoomID, event_id: EventID, + self, + message_id: str, + room_id: RoomID, + event_id: EventID, ) -> None: await self.db.execute( """ INSERT INTO matrix_message (message_id, room_id, event_id) VALUES ($1, $2, $3) ON CONFLICT (message_id, room_id) DO UPDATE SET event_id = excluded.event_id """, - message_id, room_id, event_id, + message_id, + room_id, + event_id, ) async def get_clients(self) -> list[Client]: @@ -125,12 +134,14 @@ async def put_client(self, user_id: UserID, token: str) -> None: INSERT INTO client (user_id, token) VALUES ($1, $2) ON CONFLICT (user_id) DO UPDATE SET token = excluded.token """, - user_id, token, + user_id, + token, ) async def delete_client(self, user_id: UserID) -> None: await self.db.execute( - "DELETE FROM client WHERE user_id = $1", user_id, + "DELETE FROM client WHERE user_id = $1", + user_id, ) async def get_avatars(self) -> list[Avatar]: @@ -143,33 +154,41 @@ async def put_avatar(self, url: str, mxc: ContentURI) -> None: INSERT INTO avatar (url, mxc) VALUES ($1, $2) ON CONFLICT (url) DO NOTHING """, - url, mxc, + url, + mxc, ) async def get_webhook_by_id(self, id: uuid.UUID) -> WebhookInfo | None: row = await self.db.fetchrow( - "SELECT id, repo, user_id, room_id, github_id, secret FROM webhook WHERE id = $1", id, + "SELECT id, repo, user_id, room_id, github_id, secret FROM webhook WHERE id = $1", + id, ) return WebhookInfo.from_row(row) async def get_webhook_by_repo(self, room_id: RoomID, repo: str) -> WebhookInfo | None: row = await self.db.fetchrow( - "SELECT id, repo, user_id, room_id, github_id, secret FROM webhook WHERE room_id = $1 AND repo = $2", room_id, repo, + "SELECT id, repo, user_id, room_id, github_id, secret FROM webhook WHERE room_id = $1 AND repo = $2", + room_id, + repo, ) return WebhookInfo.from_row(row) async def get_webhooks_in_room(self, room_id: RoomID) -> list[WebhookInfo]: rows = await self.db.fetch( - "SELECT id, repo, user_id, room_id, github_id, secret FROM webhook WHERE room_id = $1", room_id, + "SELECT id, repo, user_id, room_id, github_id, secret FROM webhook WHERE room_id = $1", + room_id, ) return [WebhookInfo.from_row(row) for row in rows] async def delete_webhook(self, id: uuid.UUID) -> None: await self.db.execute( - "DELETE FROM webhook WHERE id = $1", id, + "DELETE FROM webhook WHERE id = $1", + id, ) - async def insert_webhook(self, webhook: WebhookInfo, *, _conn: Connection | None = None) -> None: + async def insert_webhook( + self, webhook: WebhookInfo, *, _conn: Connection | None = None + ) -> None: await (_conn or self.db).execute( """ INSERT INTO webhook (id, repo, user_id, room_id, secret, github_id) @@ -180,17 +199,23 @@ async def insert_webhook(self, webhook: WebhookInfo, *, _conn: Connection | None async def set_webhook_github_id(self, id: uuid.UUID, github_id: int) -> None: await self.db.execute( - "UPDATE webhook SET github_id = $1 WHERE id = $2", github_id, id, + "UPDATE webhook SET github_id = $1 WHERE id = $2", + github_id, + id, ) async def transfer_webhook_repo(self, id: uuid.UUID, new_repo: str) -> None: await self.db.execute( - "UPDATE webhook SET repo = $1 WHERE id = $2", new_repo, id, + "UPDATE webhook SET repo = $1 WHERE id = $2", + new_repo, + id, ) async def transfer_webhook_rooms(self, old_room: RoomID, new_room: RoomID) -> None: await self.db.execute( - "UPDATE webhook SET room_id = $1 WHERE room_id = $2 ON CONFLICT (repo, room_id) DO NOTHING", new_room, old_room, + "UPDATE webhook SET room_id = $1 WHERE room_id = $2 ON CONFLICT (repo, room_id) DO NOTHING", + new_room, + old_room, ) async def run_post_migration(self, conn: Connection, secret_key: str) -> None: diff --git a/github/migrations.py b/github/migrations.py index 9c7c1f0..11719c6 100644 --- a/github/migrations.py +++ b/github/migrations.py @@ -17,16 +17,19 @@ upgrade_table = UpgradeTable() + @upgrade_table.register(description="Latest revision", upgrades_to=1) async def upgrade_latest(conn: Connection, scheme: Scheme) -> None: needs_migration = False if await conn.table_exists("webhook"): needs_migration = True - await conn.execute(""" + await conn.execute( + """ ALTER TABLE webhook RENAME TO webhook_old; ALTER TABLE client RENAME TO client_old; ALTER TABLE matrix_message RENAME TO matrix_message_old; - """) + """ + ) await conn.execute( f"""CREATE TABLE client ( user_id TEXT NOT NULL, @@ -42,7 +45,7 @@ async def upgrade_latest(conn: Connection, scheme: Scheme) -> None: room_id TEXT NOT NULL, secret TEXT NOT NULL, github_id INTEGER, - PRIMARY KEY (id), + PRIMARY KEY (id), CONSTRAINT webhook_repo_room_unique UNIQUE (repo, room_id) )""" ) @@ -64,7 +67,10 @@ async def upgrade_latest(conn: Connection, scheme: Scheme) -> None: if needs_migration: await migrate_legacy_to_v1(conn) + async def migrate_legacy_to_v1(conn: Connection) -> None: await conn.execute("INSERT INTO client (user_id, token) SELECT user_id, token FROM client_old") - await conn.execute("INSERT INTO matrix_message (message_id, room_id, event_id) SELECT message_id, room_id, event_id FROM matrix_message_old") + await conn.execute( + "INSERT INTO matrix_message (message_id, room_id, event_id) SELECT message_id, room_id, event_id FROM matrix_message_old" + ) await conn.execute("CREATE TABLE needs_post_migration(noop INTEGER PRIMARY KEY)") diff --git a/github/template/loader.py b/github/template/loader.py index ec2899f..7d07089 100644 --- a/github/template/loader.py +++ b/github/template/loader.py @@ -13,7 +13,7 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import Tuple, Iterable, Any, Callable +from typing import Any, Callable, Iterable, Tuple from jinja2 import BaseLoader, TemplateNotFound @@ -42,8 +42,11 @@ def get_source(self, environment: Any, name: str) -> Tuple[str, str, Callable[[] raise TemplateNotFound(name) if not tpl: raise TemplateNotFound(name) - return (self.config["macros"] + tpl, name, - lambda: self.reload_counter == cur_reload_counter) + return ( + self.config["macros"] + tpl, + name, + lambda: self.reload_counter == cur_reload_counter, + ) def list_templates(self) -> Iterable[str]: return sorted(self.config[self.field].keys()) diff --git a/github/template/manager.py b/github/template/manager.py index ab33ed9..1b8f5f9 100644 --- a/github/template/manager.py +++ b/github/template/manager.py @@ -13,7 +13,7 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import Dict, Any +from typing import Any, Dict from jinja2 import Environment as JinjaEnvironment, Template @@ -30,8 +30,9 @@ class TemplateManager: def __init__(self, config: Config, key: str) -> None: self._loader = ConfigTemplateLoader(config, key) - self._env = JinjaEnvironment(loader=self._loader, lstrip_blocks=True, trim_blocks=True, - extensions=["jinja2.ext.do"]) + self._env = JinjaEnvironment( + loader=self._loader, lstrip_blocks=True, trim_blocks=True, extensions=["jinja2.ext.do"] + ) self._env.filters["markdown"] = lambda message: markdown.render(message, allow_html=True) def __getitem__(self, item: str) -> Template: diff --git a/github/template/proxy.py b/github/template/proxy.py index fc85f75..6b39a00 100644 --- a/github/template/proxy.py +++ b/github/template/proxy.py @@ -13,7 +13,7 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import Dict, Any +from typing import Any, Dict from jinja2 import Environment as JinjaEnvironment, TemplateNotFound diff --git a/github/template/util.py b/github/template/util.py index 58f2705..33128eb 100644 --- a/github/template/util.py +++ b/github/template/util.py @@ -13,7 +13,7 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import List, Callable +from typing import Callable, List from ..util import contrast, hex_to_rgb @@ -26,9 +26,11 @@ class TemplateUtil: @classmethod def contrast_fg(cls, color: str) -> str: - return (cls.white_hex - if contrast(hex_to_rgb(color), cls.white_rgb) >= cls.contrast_threshold - else cls.black_hex) + return ( + cls.white_hex + if contrast(hex_to_rgb(color), cls.white_rgb) >= cls.contrast_threshold + else cls.black_hex + ) @staticmethod def cut_message(message: str, max_len: int = 72) -> str: @@ -55,8 +57,13 @@ def ref_name(ref: str) -> str: return ref.split("/", 2)[2] @staticmethod - def join_human_list(data: List[str], *, joiner: str = ", ", final_joiner: str = " and ", - mutate: Callable[[str], str] = lambda val: val) -> str: + def join_human_list( + data: List[str], + *, + joiner: str = ", ", + final_joiner: str = " and ", + mutate: Callable[[str], str] = lambda val: val, + ) -> str: if not data: return "" elif len(data) == 1: diff --git a/github/util/__init__.py b/github/util/__init__.py index b02c41d..d5eba73 100644 --- a/github/util/__init__.py +++ b/github/util/__init__.py @@ -1,2 +1,2 @@ -from .recursive_get import recursive_get from .contrast import contrast, hex_to_rgb, rgb_to_hex +from .recursive_get import recursive_get diff --git a/github/util/contrast.py b/github/util/contrast.py index 354c155..42aad18 100644 --- a/github/util/contrast.py +++ b/github/util/contrast.py @@ -14,8 +14,8 @@ def hex_to_rgb(color: str) -> RGB: step = 1 if len(color) == 3 else 2 try: r = int(color[0:step], 16) - g = int(color[step:2 * step], 16) - b = int(color[2 * step:3 * step], 16) + g = int(color[step : 2 * step], 16) + b = int(color[2 * step : 3 * step], 16) except ValueError as e: raise ValueError("Invalid hex value") from e return r / 255, g / 255, b / 255 diff --git a/github/webhook/__init__.py b/github/webhook/__init__.py index 0ac3076..c54d571 100644 --- a/github/webhook/__init__.py +++ b/github/webhook/__init__.py @@ -1,3 +1,3 @@ -from .handler import WebhookHandler from .aggregation import PendingAggregation +from .handler import WebhookHandler from .manager import WebhookInfo, WebhookManager diff --git a/github/webhook/aggregation.py b/github/webhook/aggregation.py index e737c3c..c6aca67 100644 --- a/github/webhook/aggregation.py +++ b/github/webhook/aggregation.py @@ -13,12 +13,19 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import Dict, Tuple, Set, Callable, Type, Optional, Any, TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Set, Tuple, Type import asyncio +from ..api.types import ( + ACTION_CLASSES, + Action, + CommentAction, + Event, + EventType, + IssueAction, + PullRequestAction, +) from .manager import WebhookInfo -from ..api.types import (Event, EventType, Action, IssueAction, PullRequestAction, CommentAction, - ACTION_CLASSES) if TYPE_CHECKING: from .handler import WebhookHandler @@ -40,8 +47,9 @@ def start_label_aggregation(self) -> None: self.event.action = self.action_type.X_LABEL_AGGREGATE def start_open_label_dropping(self) -> None: - event_field = (self.event.issue if self.event_type == EventType.ISSUES - else self.event.pull_request) + event_field = ( + self.event.issue if self.event_type == EventType.ISSUES else self.event.pull_request + ) self._label_ids = {label.id for label in event_field.labels} def start_milestone_aggregation(self) -> None: @@ -66,7 +74,7 @@ def start_milestone_aggregation(self) -> None: timeout = 3 - handler: 'WebhookHandler' + handler: "WebhookHandler" webhook_info: WebhookInfo delivery_ids: Set[str] event_type: EventType @@ -77,8 +85,14 @@ def start_milestone_aggregation(self) -> None: _label_ids: Optional[Set[int]] - def __init__(self, handler: 'WebhookHandler', evt_type: EventType, evt: Event, delivery_id: str, - webhook_info: WebhookInfo) -> None: + def __init__( + self, + handler: "WebhookHandler", + evt_type: EventType, + evt: Event, + delivery_id: str, + webhook_info: WebhookInfo, + ) -> None: self.handler = handler self.webhook_info = webhook_info self.event_type = evt_type @@ -131,22 +145,31 @@ async def _send(self) -> None: aggregation=self.aggregation, ) if self.event_type == EventType.PUSH and event_id: - await self.handler.bot.db.put_event(self.event.message_id, self.webhook_info.room_id, event_id) + await self.handler.bot.db.put_event( + self.event.message_id, self.webhook_info.room_id, event_id + ) def aggregate(self, evt_type: EventType, evt: Event, delivery_id: str) -> bool: postpone = True - if (evt_type == EventType.ISSUES and self.event_type == EventType.ISSUE_COMMENT - and evt.action in (IssueAction.CLOSED, IssueAction.REOPENED) - and evt.issue_id == self.event.issue_id - and self.event.sender.id == evt.sender.id): + if ( + evt_type == EventType.ISSUES + and self.event_type == EventType.ISSUE_COMMENT + and evt.action in (IssueAction.CLOSED, IssueAction.REOPENED) + and evt.issue_id == self.event.issue_id + and self.event.sender.id == evt.sender.id + ): if evt.action == IssueAction.CLOSED: self.aggregation["closed"] = True elif evt.action == IssueAction.REOPENED: self.aggregation["reopened"] = True - elif (evt_type == EventType.ISSUE_COMMENT and self.event_type == EventType.ISSUES - and evt.action == CommentAction.CREATED and self.event.sender.id == evt.sender.id - and evt.issue_id == self.event.issue_id - and self.event.action in (IssueAction.CLOSED, IssueAction.REOPENED)): + elif ( + evt_type == EventType.ISSUE_COMMENT + and self.event_type == EventType.ISSUES + and evt.action == CommentAction.CREATED + and self.event.sender.id == evt.sender.id + and evt.issue_id == self.event.issue_id + and self.event.action in (IssueAction.CLOSED, IssueAction.REOPENED) + ): self.event_type = evt_type self.event = evt if evt.action == IssueAction.CLOSED: @@ -155,10 +178,15 @@ def aggregate(self, evt_type: EventType, evt: Event, delivery_id: str) -> bool: self.aggregation["reopened"] = True elif evt_type != self.event_type: return False - elif (self.event_type in (EventType.ISSUES, EventType.PULL_REQUEST) - and evt.issue_id == self.event.issue_id): - if (self.event.action == self.action_type.OPENED - and evt.label and evt.label.id in self._label_ids): + elif ( + self.event_type in (EventType.ISSUES, EventType.PULL_REQUEST) + and evt.issue_id == self.event.issue_id + ): + if ( + self.event.action == self.action_type.OPENED + and evt.label + and evt.label.id in self._label_ids + ): # Label was already in original event, drop the message. pass elif self.event.action == self.action_type.X_LABEL_AGGREGATE: diff --git a/github/webhook/handler.py b/github/webhook/handler.py index a4755ec..90eb01b 100644 --- a/github/webhook/handler.py +++ b/github/webhook/handler.py @@ -13,8 +13,8 @@ # # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . -from typing import Dict, Set, Deque, Optional, Any, TYPE_CHECKING -from collections import deque, defaultdict +from typing import TYPE_CHECKING, Any, Deque, Dict, Optional, Set +from collections import defaultdict, deque from uuid import UUID import asyncio import logging @@ -36,21 +36,21 @@ ) from mautrix.util.formatter import parse_html -from ..template import TemplateManager, TemplateUtil from ..api.types import ( + ACTION_CLASSES, + OTHER_ENUMS, Event, EventType, MetaAction, PushEvent, RepositoryAction, + User, WorkflowJobEvent, expand_enum, - ACTION_CLASSES, - OTHER_ENUMS, - User, ) -from .manager import WebhookInfo +from ..template import TemplateManager, TemplateUtil from .aggregation import PendingAggregation +from .manager import WebhookInfo if TYPE_CHECKING: from ..bot import GitHubBot @@ -61,13 +61,13 @@ class WebhookHandler: log: logging.Logger - bot: 'GitHubBot' + bot: "GitHubBot" msgtype: MessageType messages: TemplateManager templates: TemplateManager pending_aggregations: Dict[UUID, Deque[PendingAggregation]] - def __init__(self, bot: 'GitHubBot') -> None: + def __init__(self, bot: "GitHubBot") -> None: self.bot = bot self.log = self.bot.log.getChild("webhook") self.msgtype = MessageType(bot.config["message_options.msgtype"]) or MessageType.NOTICE @@ -79,10 +79,14 @@ def __init__(self, bot: 'GitHubBot') -> None: def reload_config(self) -> None: self.messages.reload() self.templates.reload() - self.msgtype = MessageType(self.bot.config["message_options.msgtype"]) or MessageType.NOTICE + self.msgtype = ( + MessageType(self.bot.config["message_options.msgtype"]) or MessageType.NOTICE + ) PendingAggregation.timeout = int(self.bot.config["message_options.aggregation_timeout"]) - async def __call__(self, evt_type: EventType, evt: Event, delivery_id: str, info: WebhookInfo) -> None: + async def __call__( + self, evt_type: EventType, evt: Event, delivery_id: str, info: WebhookInfo + ) -> None: if evt_type == EventType.PING: self.log.debug(f"Received ping for {info}: {evt.zen}") await self.bot.webhook_manager.set_github_id(info, evt.hook_id) diff --git a/github/webhook/manager.py b/github/webhook/manager.py index 0aa24d1..f619738 100644 --- a/github/webhook/manager.py +++ b/github/webhook/manager.py @@ -16,7 +16,7 @@ from uuid import UUID, uuid4 import random -from mautrix.types import UserID, RoomID +from mautrix.types import RoomID, UserID from ..db import DBManager, WebhookInfo @@ -30,11 +30,13 @@ def __init__(self, db: DBManager): self._webhooks = {} async def create(self, repo: str, user_id: UserID, room_id: RoomID) -> WebhookInfo: - info = WebhookInfo(id=uuid4(), - repo=repo, - user_id=user_id, - room_id=room_id, - secret=random.randbytes(16).hex()) + info = WebhookInfo( + id=uuid4(), + repo=repo, + user_id=user_id, + room_id=room_id, + secret=random.randbytes(16).hex(), + ) await self._db.insert_webhook(info) self._webhooks[info.id] = info return info diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..f143797 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,11 @@ +[tool.isort] +profile = "black" +force_to_top = "typing" +from_first = true +combine_as_imports = true +known_first_party = ["mautrix", "maubot"] +line_length = 99 + +[tool.black] +line-length = 99 +target-version = ["py310"] From b55f2a5676ed6b422edb8b143213981d56fe00b9 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 7 May 2025 16:46:40 +0300 Subject: [PATCH 41/43] Handle room tombstones and fix bugs --- base-config.yaml | 2 +- github/avatar_manager.py | 4 +++- github/bot.py | 1 + github/client_manager.py | 2 +- github/db.py | 19 +++++++++++++------ github/migrations.py | 10 +++------- github/webhook/manager.py | 8 +++++++- 7 files changed, 29 insertions(+), 17 deletions(-) diff --git a/base-config.yaml b/base-config.yaml index f913185..5ea2fa1 100644 --- a/base-config.yaml +++ b/base-config.yaml @@ -3,7 +3,7 @@ client_id: 123ab456789cd0123456 client_secret: 12345a6789b0123cd45e6f7g8h9i01234567890j # Random string for signing webhook secrets. If set to "generate", one will be generated for you. -# WARNING: Changing this will invalidate all created webhooks. +# This is no longer used for new webhooks, it is only needed while migrating. webhook_key: generate # A global webhook secret that can be used to manually created webhooks. # This is useful if you're self-hosting, but don't want to set up the GitHub OAuth stuff. diff --git a/github/avatar_manager.py b/github/avatar_manager.py index 1c28949..5ef699e 100644 --- a/github/avatar_manager.py +++ b/github/avatar_manager.py @@ -25,7 +25,9 @@ def __init__(self, bot: "GitHubBot") -> None: self._avatars = {} async def load_db(self) -> None: - self._avatars = {url: ContentURI(mxc) for url, mxc in await self._db.get_avatars()} + self._avatars = { + avatar.url: ContentURI(avatar.mxc) for avatar in await self._db.get_avatars() + } async def get_mxc(self, url: str) -> ContentURI: try: diff --git a/github/bot.py b/github/bot.py index d443469..5b2172b 100644 --- a/github/bot.py +++ b/github/bot.py @@ -47,6 +47,7 @@ async def start(self) -> None: self.log.info("Running database post-migration") async with self.database.acquire() as conn, conn.transaction(): await self.db.run_post_migration(conn, self.config["webhook_key"]) + self.log.info("Webhook secret migration completed successfully") self.clients = ClientManager( self.config["client_id"], self.config["client_secret"], self.http, self.db ) diff --git a/github/client_manager.py b/github/client_manager.py index f0efc20..0e7f2be 100644 --- a/github/client_manager.py +++ b/github/client_manager.py @@ -38,7 +38,7 @@ def __init__(self, client_id: str, client_secret: str, http: ClientSession, db: async def load_db(self) -> None: self._clients = { - user_id: self._make(token) for user_id, token in await self._db.get_clients() + cli.user_id: self._make(cli.token) for cli in await self._db.get_clients() } def _make(self, token: str) -> GitHubClient: diff --git a/github/db.py b/github/db.py index 2bd1c69..cc110cd 100644 --- a/github/db.py +++ b/github/db.py @@ -161,7 +161,7 @@ async def put_avatar(self, url: str, mxc: ContentURI) -> None: async def get_webhook_by_id(self, id: uuid.UUID) -> WebhookInfo | None: row = await self.db.fetchrow( "SELECT id, repo, user_id, room_id, github_id, secret FROM webhook WHERE id = $1", - id, + str(id), ) return WebhookInfo.from_row(row) @@ -183,7 +183,7 @@ async def get_webhooks_in_room(self, room_id: RoomID) -> list[WebhookInfo]: async def delete_webhook(self, id: uuid.UUID) -> None: await self.db.execute( "DELETE FROM webhook WHERE id = $1", - id, + str(id), ) async def insert_webhook( @@ -194,21 +194,26 @@ async def insert_webhook( INSERT INTO webhook (id, repo, user_id, room_id, secret, github_id) VALUES ($1, $2, $3, $4, $5, $6) """, - *webhook, + str(webhook.id), + webhook.repo, + webhook.user_id, + webhook.room_id, + webhook.secret, + webhook.github_id, ) async def set_webhook_github_id(self, id: uuid.UUID, github_id: int) -> None: await self.db.execute( "UPDATE webhook SET github_id = $1 WHERE id = $2", github_id, - id, + str(id), ) async def transfer_webhook_repo(self, id: uuid.UUID, new_repo: str) -> None: await self.db.execute( "UPDATE webhook SET repo = $1 WHERE id = $2", new_repo, - id, + str(id), ) async def transfer_webhook_rooms(self, old_room: RoomID, new_room: RoomID) -> None: @@ -219,7 +224,9 @@ async def transfer_webhook_rooms(self, old_room: RoomID, new_room: RoomID) -> No ) async def run_post_migration(self, conn: Connection, secret_key: str) -> None: - rows = await self.db.fetch("SELECT id, repo, user_id, room_id, github_id FROM webhook_old") + rows = list( + await conn.fetch("SELECT id, repo, user_id, room_id, github_id FROM webhook_old") + ) for row in rows: id = uuid.UUID(row["id"]) secret = hmac.new(key=secret_key.encode("utf-8"), digestmod=hashlib.sha256) diff --git a/github/migrations.py b/github/migrations.py index 11719c6..50ad2c6 100644 --- a/github/migrations.py +++ b/github/migrations.py @@ -23,13 +23,9 @@ async def upgrade_latest(conn: Connection, scheme: Scheme) -> None: needs_migration = False if await conn.table_exists("webhook"): needs_migration = True - await conn.execute( - """ - ALTER TABLE webhook RENAME TO webhook_old; - ALTER TABLE client RENAME TO client_old; - ALTER TABLE matrix_message RENAME TO matrix_message_old; - """ - ) + await conn.execute("ALTER TABLE webhook RENAME TO webhook_old;") + await conn.execute("ALTER TABLE client RENAME TO client_old;") + await conn.execute("ALTER TABLE matrix_message RENAME TO matrix_message_old;") await conn.execute( f"""CREATE TABLE client ( user_id TEXT NOT NULL, diff --git a/github/webhook/manager.py b/github/webhook/manager.py index f619738..dfcda4e 100644 --- a/github/webhook/manager.py +++ b/github/webhook/manager.py @@ -16,7 +16,8 @@ from uuid import UUID, uuid4 import random -from mautrix.types import RoomID, UserID +from maubot.handlers import event +from mautrix.types import EventType, RoomID, StateEvent, UserID from ..db import DBManager, WebhookInfo @@ -80,3 +81,8 @@ async def get_all_for_room(self, room_id: RoomID) -> list[WebhookInfo]: for item in items: self._webhooks[item.id] = item return items + + @event.on(EventType.ROOM_TOMBSTONE) + async def handle_tombstone(self, evt: StateEvent) -> None: + if evt.state_key == "" and evt.content.replacement_room: + await self.transfer_rooms(evt.room_id, evt.content.replacement_room) From 4daaf1a318cf03bbadebb3dcaedfcdaafbf6b355 Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 7 May 2025 17:00:49 +0300 Subject: [PATCH 42/43] Fix paths in github actions --- .github/workflows/python-lint.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/python-lint.yml b/.github/workflows/python-lint.yml index 18be560..ac49a0f 100644 --- a/.github/workflows/python-lint.yml +++ b/.github/workflows/python-lint.yml @@ -12,10 +12,10 @@ jobs: python-version: "3.13" - uses: isort/isort-action@master with: - sortPaths: "./rss" + sortPaths: "./github" - uses: psf/black@stable with: - src: "./rss" + src: "./github" - name: pre-commit run: | pip install pre-commit From 4ba9cdb27bb573b44971ebe38b7cd8cd6b119f4c Mon Sep 17 00:00:00 2001 From: Tulir Asokan Date: Wed, 7 May 2025 17:04:14 +0300 Subject: [PATCH 43/43] Fix tombstone handling --- github/bot.py | 1 + github/db.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/github/bot.py b/github/bot.py index 5b2172b..69b9bb6 100644 --- a/github/bot.py +++ b/github/bot.py @@ -67,6 +67,7 @@ async def start(self) -> None: self.register_handler_class(self.webhook_receiver) self.register_handler_class(self.clients) self.register_handler_class(self.commands) + self.register_handler_class(self.webhook_manager) async def reset_tokens(self) -> None: try: diff --git a/github/db.py b/github/db.py index cc110cd..6fadb08 100644 --- a/github/db.py +++ b/github/db.py @@ -218,7 +218,7 @@ async def transfer_webhook_repo(self, id: uuid.UUID, new_repo: str) -> None: async def transfer_webhook_rooms(self, old_room: RoomID, new_room: RoomID) -> None: await self.db.execute( - "UPDATE webhook SET room_id = $1 WHERE room_id = $2 ON CONFLICT (repo, room_id) DO NOTHING", + "UPDATE webhook SET room_id = $1 WHERE room_id = $2", new_room, old_room, )