diff --git a/.editorconfig b/.editorconfig
index bd22de57..ecced6df 100644
--- a/.editorconfig
+++ b/.editorconfig
@@ -7,13 +7,13 @@ indent_size = 4
 indent_style = space
 insert_final_newline = false
 max_line_length = 88
+ij_visual_guides = 88
 tab_width = 4
-ij_continuation_indent_size = 8
+ij_continuation_indent_size = 4
 ij_formatter_off_tag = @formatter:off
 ij_formatter_on_tag = @formatter:on
 ij_formatter_tags_enabled = false
 ij_smart_tabs = false
-ij_visual_guides = 72
 ij_wrap_on_typing = false

 [*.css]
@@ -51,7 +51,6 @@ ij_markdown_min_lines_around_header = 1
 ij_markdown_min_lines_between_paragraphs = 1

 [{*.py,*.pyw}]
-ij_visual_guides = none
 ij_python_align_collections_and_comprehensions = true
 ij_python_align_multiline_imports = true
 ij_python_align_multiline_parameters = true
diff --git a/.github/scripts/license_generator.py b/.github/scripts/license_generator.py
new file mode 100644
index 00000000..703be84f
--- /dev/null
+++ b/.github/scripts/license_generator.py
@@ -0,0 +1,23 @@
+# Prepend the Pincer license header to every source file that lacks it.
+from glob import glob
+
+pincer_license = """# Copyright Pincer 2021-Present
+# Full MIT License can be found in `LICENSE` at the project root.
+
+"""
+
+
+for file in glob("./pincer/**/*.py", recursive=True):
+    if file == "./pincer/__init__.py":
+        continue
+
+    with open(file, "r+", encoding="utf-8") as f:
+        lines = f.readlines()
+        # Empty files (e.g. a bare package `__init__.py`) have no first line
+        # to inspect, so they are treated as missing the header as well.
+        if not lines or not lines[0].startswith(
+            "# Copyright Pincer 2021-Present\n"
+        ):
+            lines.insert(0, pincer_license)
+            f.seek(0)
+            f.writelines(lines)
diff --git a/.github/workflows/run_scripts.yml b/.github/workflows/run_scripts.yml
new file mode 100644
index 00000000..24b5275b
--- /dev/null
+++ b/.github/workflows/run_scripts.yml
@@ -0,0 +1,35 @@
+name: Run Scripts
+
+on: push
+
+jobs:
+  build:
+    runs-on: ubuntu-latest
+
+    steps:
+      - uses: actions/checkout@v2
+      - uses: actions/setup-python@v2
+
+        with:
+          python-version: "3.8"
+          architecture: "x64"
+
+      - name: setup git
+        run: |
+          git config user.name "GitHub Actions"
+          git config user.email "actions@pincer.dev"
+          git pull
+
+      - name: running all sorter
+        run: |
+          python ./.github/scripts/all_sorter.py
+          # Single quotes keep bash from running `__all__` as a command.
+          git commit -am ':art: Automatic `__all__` sorting' || echo "No changes to commit"
+          git push || echo "No changes to push"
+
+
+      - name: running license generator
+        run: |
+          python ./.github/scripts/license_generator.py
+          git commit -am ":page_facing_up: Automatic license generator" || echo "No changes to commit"
+          git push || echo "No changes to push"
diff --git a/.github/workflows/sort_alls.yaml b/.github/workflows/sort_alls.yaml
deleted file mode 100644
index 404edaae..00000000
--- a/.github/workflows/sort_alls.yaml
+++ /dev/null
@@ -1,32 +0,0 @@
-name: Sort Alls
-
-on: push
-
-jobs:
-  build:
-    runs-on: ubuntu-latest
-
-    steps:
-      - uses: actions/checkout@v2
-      - uses: actions/setup-python@v2
-
-        with:
-          persist-credentials: false
-          fetch-depth: 0
-          python-version: '3.8'
-          architecture: 'x64'
-
-      - name: running script
-        run: python ./.github/scripts/all_sorter.py
-
-      - name: setup git
-        run: |
-          git config user.name "GitHub Actions"
-          git config user.email "actions@pincer.dev"
-
-      - name: add to git
-        run: |
-          git pull
-          git add .
-          git diff-index --quiet HEAD || git commit -m ":art: Automatic `__all__` sorting"
-          git push
diff --git a/LICENSE b/LICENSE
index da17ed70..033af171 100644
--- a/LICENSE
+++ b/LICENSE
@@ -1,6 +1,6 @@
 MIT License

-Copyright (c) 2021 Pincer
+Copyright (c) 2021 - 2022 Pincer

 Permission is hereby granted, free of charge, to any person obtaining a copy
 of this software and associated documentation files (the "Software"), to deal
diff --git a/Pipfile b/Pipfile
index 78ac8801..252e395f 100644
--- a/Pipfile
+++ b/Pipfile
@@ -6,16 +6,16 @@ name = "pypi"
 [packages]
 websockets = ">=10.0"
 aiohttp = ">=3.7.4post0,<4.1.0"
-Pillow = "==9.0.0"
+Pillow = "==9.0.1"

 [dev-packages]
 flake8 = "==4.0.1"
 tox = "==3.24.5"
-pytest = "==6.2.5"
-pytest-asyncio = "==0.16.0"
+pytest = "==7.0.1"
+pytest-asyncio = "==0.18.1"
 pytest-cov = "==3.0.0"
-mypy = "==0.930"
-twine = "==3.7.1"
+mypy = "==0.931"
+twine = "==3.8.0"
 wheel = "==0.37.1"

 [requires]
diff --git a/VERSION b/VERSION
index 0cf69a5c..d183d4ac 100644
--- a/VERSION
+++ b/VERSION
@@ -1 +1 @@
-0.15.3
\ No newline at end of file
+0.16.0
\ No newline at end of file
diff --git a/docs/PYPI.md b/docs/PYPI.md
index e0686fd8..42939cd1 100644
--- a/docs/PYPI.md
+++ b/docs/PYPI.md
@@ -114,6 +114,9 @@ client.run()
 Pincer makes developing application commands intuitive and fast.
```py +from typing import Annotated # python 3.9+ +from typing_extensions import Annotated # python 3.8 + from pincer import Client from pincer.commands import command, CommandArg, Description from pincer.objects import UserMessage, User @@ -139,8 +142,8 @@ class Bot(Client): @command(description="Add two numbers!") async def add( self, - first: CommandArg[int, Description["The first number"]], - second: CommandArg[int, Description["The second number"]] + first: Annotated[int, Description("The first number")], + second: Annotated[int, Description("The second number")] ): return f"The addition of `{first}` and `{second}` is `{first + second}`" diff --git a/docs/README.md b/docs/README.md index 7e5a4027..4a6f167b 100644 --- a/docs/README.md +++ b/docs/README.md @@ -121,6 +121,9 @@ client.run() Pincer makes developing application commands intuitive and fast. ```py +from typing import Annotated # python 3.9+ +from typing_extensions import Annotated # python 3.8 + from pincer import Client from pincer.commands import command, CommandArg, Description from pincer.objects import UserMessage, User @@ -146,8 +149,8 @@ class Bot(Client): @command(description="Add two numbers!") async def add( self, - first: CommandArg[int, Description["The first number"]], - second: CommandArg[int, Description["The second number"]] + first: Annotated[int, Description("The first number")], + second: Annotated[int, Description("The second number")] ): return f"The addition of `{first}` and `{second}` is `{first + second}`" @@ -183,9 +186,7 @@ freedom to create custom events and remove the already existing middleware creat the developers. Your custom middleware directly receives the payload from Discord. You can't do anything wrong without accessing the `override` attribute, but if you do access it, the Pincer team will not provide any support for weird behavior. -So, in short, only use this if you know what you're doing.
An example of using -the middleware system with a custom `on_ready` event can be found -[in our docs](https://pincer.readthedocs.io/en/latest/pincer.html#pincer.client.middleware). +So, in short, only use this if you know what you're doing._ ## 🏷️ License diff --git a/docs/api/cog.rst b/docs/api/cog.rst new file mode 100644 index 00000000..ad046223 --- /dev/null +++ b/docs/api/cog.rst @@ -0,0 +1,14 @@ + +.. currentmodule:: pincer.cog + +Pincer Cog Module +================= + +cog +--- +.. attributetable:: Cog +.. autoclass:: Cog() + +.. automethod:: CogManager.load_cog +.. automethod:: CogManager.load_module +.. automethod:: CogManager.reload_cogs diff --git a/docs/api/commands.rst b/docs/api/commands.rst index 181d47be..2b45051e 100644 --- a/docs/api/commands.rst +++ b/docs/api/commands.rst @@ -27,6 +27,7 @@ Command Types .. autoclass:: MaxValue() .. autoclass:: MinValue() .. autoclass:: ChannelTypes() +.. autoclass:: CommandArg() ChatCommandHandler ------------------ @@ -66,4 +67,11 @@ Command Groups .. currentmodule:: pincer.commands.groups .. autoclass:: Group() -.. autoclass:: Subgroup() \ No newline at end of file +.. autoclass:: Subgroup() + +Interactable Objects +~~~~~~~~~~~~~~~~~~~~ +.. currentmodule:: pincer.commands.interactable + +.. attributetable:: Interactable +.. autoclass:: Interactable() diff --git a/docs/api/index.rst b/docs/api/index.rst index b3890890..415c0fb0 100644 --- a/docs/api/index.rst +++ b/docs/api/index.rst @@ -8,6 +8,7 @@ The Full Pincer API Referance pincer core + cog commands middleware objects/index diff --git a/docs/api/objects/app.rst b/docs/api/objects/app.rst index 508a77de..59ca1572 100644 --- a/docs/api/objects/app.rst +++ b/docs/api/objects/app.rst @@ -51,12 +51,12 @@ AppCommand .. autoclass:: AppCommand() -ClientCommandStructure +InteractableStructure ~~~~~~~~~~~~~~~~~~~~~~ -.. attributetable:: ClientCommandStructure +.. attributetable:: InteractableStructure -.. autoclass:: ClientCommandStructure() +.. 
autoclass:: InteractableStructure() Intents ------- diff --git a/docs/api/pincer.rst b/docs/api/pincer.rst index 97d51519..22aa5619 100644 --- a/docs/api/pincer.rst +++ b/docs/api/pincer.rst @@ -13,6 +13,7 @@ Client .. autoclass:: Client :exclude-members: event + :inherited-members: .. automethod:: Client.event() :decorator: diff --git a/docs/interactions.rst b/docs/interactions.rst index 7f25b53a..97e6a2de 100644 --- a/docs/interactions.rst +++ b/docs/interactions.rst @@ -128,24 +128,24 @@ The list of possible type hints is as follows: - Mentionable You might want to specify more information for your arguments. If you want a description for your command, you will have to use the -:class:`~pincer.commands.arg_types.Description` type. Modifier types like this need to be inside of the :class:`~pincer.commands.arg_types.CommandArg` +:class:`~pincer.commands.arg_types.Description` type. Modifier types like this need to be inside of the :class:`~typing.Annotated` type. .. code-block:: python - from pincer.commands import CommandArg, Description + from typing import Annotated # Python 3.9+ + from typing_extensions import Annotated # Python 3.8 + + from pincer.commands import Description from pincer.objects import MessageContext @command async def say( self, ctx: MessageContext, - word: CommandArg[ + word: Annotated[ str, - # This will likely be marked as incorrect by your linter but it is - # valid Python. Simply append # type: ignore for most linters and - # noqa: F722 if you are using Flake8. 
- Description["A word that the bot will say."] # type: ignore # noqa: F722 + Description("A word that the bot will say.") # type: ignore # noqa: F722 ] ): # Returns the name of the user that initiated the interaction diff --git a/examples/basic_cogs/cogs.py b/examples/basic_cogs/cogs.py new file mode 100644 index 00000000..55eb228b --- /dev/null +++ b/examples/basic_cogs/cogs.py @@ -0,0 +1,38 @@ +from typing import Any, Dict, List +from pincer import Client, Cog, command +from pincer.objects import MessageContext, Embed + + +class ErrorHandler(Cog): + @Client.event + async def on_command_error( + self, + ctx: MessageContext, + error: Exception, + args: List[Any], + kwargs: Dict[str, Any] + ): + return Embed( + "Oops...", + "An error occurred while trying to execute the " + f"`{ctx.interaction.data.name}` command! Please retry later!", + color=0xff0000 + ).add_field( + "Exception:", + f"```\n{type(error).__name__}:\n{error}\n```" + ) + + +class OnReadyCog(Cog): + @Client.event + async def on_ready(self): + print( + f"Started client on {self.client.bot}\n" + "Registered commands: " + ", ".join(self.client.chat_commands) + ) + + +class SayCog(Cog): + @command(description="Say something as the bot!") + async def say(self, ctx: MessageContext, message: str): + return Embed(description=f"{ctx.author.mention} said:\n{message}") diff --git a/examples/basic_cogs/cogs/error_handler.py b/examples/basic_cogs/cogs/error_handler.py deleted file mode 100644 index 5b12bcf3..00000000 --- a/examples/basic_cogs/cogs/error_handler.py +++ /dev/null @@ -1,19 +0,0 @@ -from pincer import Client -from pincer.objects import MessageContext, Embed - - -class ErrorHandler: - @Client.event - async def on_command_error(self, ctx: MessageContext, error: Exception): - return Embed( - "Oops...", - "An error occurred while trying to execute the " - f"`{ctx.command.app.name}` command! 
Please retry later!", - color=0xff0000 - ).add_field( - "Exception:", - f"```\n{type(error).__name__}:\n{error}\n```" - ) - - -setup = ErrorHandler diff --git a/examples/basic_cogs/cogs/on_ready.py b/examples/basic_cogs/cogs/on_ready.py deleted file mode 100644 index eb231b33..00000000 --- a/examples/basic_cogs/cogs/on_ready.py +++ /dev/null @@ -1,16 +0,0 @@ -from pincer import Client - - -class OnReadyCog: - def __init__(self, client: Client): - self.client = client - - @Client.event - async def on_ready(self): - print( - f"Started client on {self.client.bot}\n" - "Registered commands: " + ", ".join(self.client.chat_commands) - ) - - -setup = OnReadyCog diff --git a/examples/basic_cogs/cogs/say.py b/examples/basic_cogs/cogs/say.py deleted file mode 100644 index 59db33ca..00000000 --- a/examples/basic_cogs/cogs/say.py +++ /dev/null @@ -1,11 +0,0 @@ -from pincer import command -from pincer.objects import Embed, MessageContext - - -class SayCog: - @command(description="Say something as the bot!") - async def say(self, ctx: MessageContext, message: str): - return Embed(description=f"{ctx.author.user.mention} said:\n{message}") - - -setup = SayCog diff --git a/examples/basic_cogs/run.py b/examples/basic_cogs/run.py index 132f6e28..659f7c2c 100644 --- a/examples/basic_cogs/run.py +++ b/examples/basic_cogs/run.py @@ -1,18 +1,13 @@ -from glob import glob - from pincer import Client +from cogs import OnReadyCog, SayCog, ErrorHandler + class Bot(Client): def __init__(self, *args, **kwargs): - self.load_cogs() + super().load_cogs(OnReadyCog, SayCog, ErrorHandler) super().__init__(*args, **kwargs) - def load_cogs(self): - """Load all cogs from the `cogs` directory.""" - for cog in glob("cogs/*.py"): - self.load_cog(cog.replace("/", ".").replace("\\", ".")[:-3]) - if __name__ == "__main__": Bot("XXXYOURBOTTOKENHEREXXX").run() diff --git a/examples/tweet_generator/tweet_generator.py b/examples/tweet_generator/tweet_generator.py index 1ffdeb23..b2fbb5c3 100644 --- 
a/examples/tweet_generator/tweet_generator.py +++ b/examples/tweet_generator/tweet_generator.py @@ -12,12 +12,9 @@ # you need to manually download the font files and put them into the folder # ./examples/tweet_generator/ to make the script works using this link: # https://fonts.google.com/share?selection.family=Noto%20Sans:wght@400;700 -if not all( - font in os.listdir() - for font in [ - "NotoSans-Regular.ttf", - "NotoSans-Bold.ttf" - ] +if any( + font not in os.listdir() + for font in ["NotoSans-Regular.ttf", "NotoSans-Bold.ttf"] ): print( "You don't have the font files installed! you need to manually " diff --git a/packages/dev.txt b/packages/dev.txt index 69b309fe..480c091b 100644 --- a/packages/dev.txt +++ b/packages/dev.txt @@ -1,7 +1,7 @@ -coverage==6.2 +coverage==6.3.2 flake8==4.0.1 tox==3.24.4 -pre-commit==2.16.0 +pre-commit==2.17.0 pytest==6.2.5 pytest-cov==3.0.0 mypy==0.910 diff --git a/packages/img.txt b/packages/img.txt index 84ed6e43..36a2029f 100644 --- a/packages/img.txt +++ b/packages/img.txt @@ -1 +1,2 @@ Pillow==8.4.0 +types-Pillow==9.0.6 diff --git a/pincer/__init__.py b/pincer/__init__.py index 0bd44b67..f3c25a58 100644 --- a/pincer/__init__.py +++ b/pincer/__init__.py @@ -11,6 +11,7 @@ from ._config import GatewayConfig from .client import event_middleware, Client, Bot +from .cog import Cog from .commands import command, ChatCommandHandler from .exceptions import ( PincerError, InvalidPayload, UnhandledException, NoExportMethod, @@ -55,11 +56,11 @@ def __repr__(self) -> str: ) -version_info = VersionInfo(0, 15, 3) +version_info = VersionInfo(0, 16, 0) __version__ = repr(version_info) __all__ = ( - "BadRequestError", "Bot", "ChatCommandHandler", "Client", + "BadRequestError", "Bot", "ChatCommandHandler", "Client", "Cog", "CogAlreadyExists", "CogError", "CogNotFound", "CommandAlreadyRegistered", "CommandCooldownError", "CommandDescriptionTooLong", "CommandError", "CommandIsNotCoroutine", "CommandReturnIsEmpty", "DisallowedIntentsError", diff 
--git a/pincer/client.py b/pincer/client.py index 7adcfdea..fb9d480d 100644 --- a/pincer/client.py +++ b/pincer/client.py @@ -4,6 +4,7 @@ from __future__ import annotations import logging +import signal from asyncio import ( iscoroutinefunction, ensure_future, @@ -12,7 +13,6 @@ ) from collections import defaultdict from functools import partial -from importlib import import_module from inspect import isasyncgenfunction from typing import ( Any, @@ -20,23 +20,25 @@ List, Optional, Iterable, + OrderedDict, Tuple, Union, overload, TYPE_CHECKING, ) + + +from .cog import CogManager +from .commands.interactable import Interactable +from .objects.app.command import InteractableStructure + from . import __package__ from .commands import ChatCommandHandler from .core import HTTPClient from .core.gateway import GatewayInfo, Gateway -from .exceptions import ( - InvalidEventName, - TooManySetupArguments, - NoValidSetupMethod, - NoCogManagerReturnFound, - CogAlreadyExists, - CogNotFound, -) + +from .exceptions import InvalidEventName, GatewayConnectionError + from .middleware import middleware from .objects import ( Role, @@ -50,6 +52,8 @@ UserMessage, Connection, File, + StageInstance, + PrivacyLevel, ) from .objects.guild.channel import GroupDMChannel from .utils import APIObject @@ -57,12 +61,11 @@ from .utils.event_mgr import EventMgr from .utils.extraction import get_index from .utils.insertion import should_pass_cls, should_pass_gateway -from .utils.signature import get_params +from .utils.shards import calculate_shard_id from .utils.types import CheckFunction from .utils.types import Coro if TYPE_CHECKING: - from .objects.app import AppCommand from .utils.snowflake import Snowflake from .core.dispatch import GatewayDispatch from .objects.app.throttling import ThrottleInterface @@ -74,7 +77,7 @@ MiddlewareType = Optional[Union[Coro, Tuple[str, List[Any], Dict[str, Any]]]] -_event = Union[str, Coro] +_event = Union[str, InteractableStructure[None]] _events: Dict[str,
Optional[Union[List[_event], _event]]] = defaultdict(list) @@ -159,7 +162,7 @@ async def wrapper(cls, gateway: Gateway, payload: GatewayDispatch): event_middleware(event)(middleware_) -class Client: +class Client(Interactable, CogManager): """The client is the main instance which is between the programmer and the discord API. @@ -201,6 +204,17 @@ def __init__( throttler: ThrottleInterface = DefaultThrottleHandler, reconnect: bool = True, ): + def sigint_handler(_signal, _frame): + _log.info("SIGINT received, shutting down...") + + # A print statement to make sure the user sees the message + print("Closing the client loop, this can take a few seconds...") + + create_task(self.http.close()) + if self.loop.is_running(): + self.loop.stop() + + signal.signal(signal.SIGINT, sigint_handler) if isinstance(intents, Iterable): intents = sum(intents) @@ -215,23 +229,27 @@ def __init__( self.bot: Optional[User] = None self.received_message = received or "Command arrived successfully!" self.http = HTTPClient(token) - APIObject.link(self) + APIObject.bind_client(self) self.throttler = throttler - self.event_mgr = EventMgr() async def get_gateway(): return GatewayInfo.from_dict(await self.http.get("gateway/bot")) - loop = get_event_loop() - self.gateway: GatewayInfo = loop.run_until_complete(get_gateway()) + self.loop = get_event_loop() + self.event_mgr = EventMgr(self.loop) + + self.gateway: GatewayInfo = self.loop.run_until_complete(get_gateway()) + self.shards: OrderedDict[int, Gateway] = OrderedDict() # The guild and channel value is only registered if the Client has the GUILDS # intent.
self.guilds: Dict[Snowflake, Optional[Guild]] = {} self.channels: Dict[Snowflake, Optional[Channel]] = {} - ChatCommandHandler.managers[self.__module__] = self + ChatCommandHandler.managers.append(self) + + super().__init__() @property def chat_commands(self) -> List[str]: @@ -240,7 +258,7 @@ def chat_commands(self) -> List[str]: Get a list of chat command calls which have been registered in the :class:`~pincer.commands.ChatCommandHandler`\\. """ - return [cmd.app.name for cmd in ChatCommandHandler.register.values()] + return [cmd.metadata.name for cmd in ChatCommandHandler.register.values()] @property def guild_ids(self) -> List[Snowflake]: @@ -322,17 +340,23 @@ async def on_ready(self): "it gets treated as a command and can have a response." ) - _events[name].append(coroutine) - return coroutine + event = InteractableStructure(call=coroutine) + + _events[name].append(event) + return event @staticmethod - def get_event_coro(name: str) -> List[Optional[Coro]]: + def get_event_coro(name: str) -> List[Optional[InteractableStructure[None]]]: """get the coroutine for an event Parameters ---------- name : :class:`str` name of the event + + Returns + ------- + List[Optional[:class:`~pincer.objects.app.command.InteractableStructure`[None]]] """ calls = _events.get(name.strip().lower()) @@ -342,135 +366,22 @@ def get_event_coro(name: str) -> List[Optional[Coro]]: else [ call for call in calls - if iscoroutinefunction(call) or isasyncgenfunction(call) + if isinstance(call, InteractableStructure) ] ) - def load_cog(self, path: str, package: Optional[str] = None): - """Load a cog from a string path, setup method in COG may - optionally have a first argument which will contain the client! - - :Example usage: - - run.py - - .. code-block:: python3 - - from pincer import Client - - class MyClient(Client): - def __init__(self, *args, **kwargs): - self.load_cog("cogs.say") - super().__init__(*args, **kwargs) - - cogs/say.py - - .. 
code-block:: python3 - - from pincer import command - - class SayCommand: - @command() - async def say(self, message: str) -> str: - return message - - setup = SayCommand - - Parameters - ---------- - path : :class:`str` - The import path for the cog. - package : :class:`str` - The package name for relative based imports. - |default| :data:`None` - """ - - if ChatCommandHandler.managers.get(path): - raise CogAlreadyExists( - f"Cog `{path}` is trying to be loaded but already exists." - ) - - try: - module = import_module(path, package=package) - except ModuleNotFoundError: - raise CogNotFound(f"Cog `{path}` could not be found!") - - setup = getattr(module, "setup", None) - - if not callable(setup): - raise NoValidSetupMethod( - f"`setup` method was expected in `{path}` but none was found!" - ) - - args, params = [], get_params(setup) - - if len(params) == 1: - args.append(self) - elif (length := len(params)) > 1: - raise TooManySetupArguments( - f"Setup method in `{path}` requested {length} arguments " - f"but the maximum is 1!" - ) - - cog_manager = setup(*args) - - if not cog_manager: - raise NoCogManagerReturnFound( - f"Setup method in `{path}` didn't return a cog manager! " - "(Did you forget to return the cog?)" - ) - - ChatCommandHandler.managers[path] = cog_manager - @staticmethod - def get_cogs() -> Dict[str, Any]: - """Get a dictionary of all loaded cogs. - - The key/value pair is import path/cog class.
- - Returns - ------- - Dict[:class:`str`, Any] - The dictionary of cogs - """ - return ChatCommandHandler.managers - - async def unload_cog(self, path: str): - """|coro| - - Unloads a currently loaded Cog - - Parameters - ---------- - path : :class:`str` - The path to the cog - - Raises - ------ - CogNotFound - When the cog is not in that path - """ - if not ChatCommandHandler.managers.get(path): - raise CogNotFound(f"Cog `{path}` could not be found!") - - to_remove: List[AppCommand] = [] - - for command in ChatCommandHandler.register.values(): - if not command: - continue - - if command.call.__module__ == path: - to_remove.append(command.app) - - await ChatCommandHandler(self).remove_commands(to_remove) - - @staticmethod - def execute_event(calls: List[Coro], gateway: Gateway, *args, **kwargs): + def execute_event( + events: List[InteractableStructure], + gateway: Gateway, + *args, + **kwargs + ): """Invokes an event. Parameters ---------- - calls: :class:`~pincer.utils.types.Coro` + events: List[:class:`~pincer.objects.app.command.InteractableStructure`] The call (method) to which the event is registered. \\*args: @@ -480,24 +391,23 @@ def execute_event(calls: List[Coro], gateway: Gateway, *args, **kwargs): The named arguments for the event.
""" - for call in calls: + for event in events: call_args = args - if should_pass_cls(call): + if should_pass_cls(event.call): call_args = ( - ChatCommandHandler.managers[call.__module__], + event.manager, *remove_none(args), ) - if should_pass_gateway(call): + if should_pass_gateway(event.call): call_args = (call_args[0], gateway, *call_args[1:]) - ensure_future(call(*call_args, **kwargs)) + ensure_future(event.call(*call_args, **kwargs)) def run(self): """Start the bot.""" - loop = get_event_loop() - ensure_future(self.start_shard(0, 1), loop=loop) - loop.run_forever() + ensure_future(self.start_shard(0, 1), loop=self.loop) + self.loop.run_forever() def run_autosharded(self): """ @@ -515,12 +425,10 @@ def run_shards(self, shards: Iterable, num_shards: int): num_shards: int The total amount of shards. """ - loop = get_event_loop() - for shard in shards: - ensure_future(self.start_shard(shard, num_shards), loop=loop) + ensure_future(self.start_shard(shard, num_shards), loop=self.loop) - loop.run_forever() + self.loop.run_forever() async def start_shard(self, shard: int, num_shards: int): """|coro| @@ -552,13 +460,67 @@ async def start_shard(self, shard: int, num_shards: int): } ) + self.shards[gateway.shard] = gateway create_task(gateway.start_loop()) - def __del__(self): - """Ensure close of the http client.""" + def get_shard( + self, + guild_id: Optional[Snowflake] = None, + num_shards: Optional[int] = None, + ) -> Gateway: + """ + Returns the shard receiving events for a specified guild_id. + + ``num_shards`` is inferred from the num_shards value for the first started + shard. If your shards do not all have the same ``num_shard`` value, you must + specify value to get the expected result. + + Parameters + ---------- + guild_id : Optional[:class:`~pincer.utils.snowflake.Snowflake`] + The guild_id of the shard to look for. If no guild id is provided, the + shard that receives dms will be returned. 
|default| :data:`None` + num_shards : Optional[:class:`int`] + The number of shards. If no number is provided, the value will default to + the num_shards for the first started shard. |default| :data:`None` + """ + if not self.shards: + raise GatewayConnectionError( + "The client has never connected to a gateway" + ) + if guild_id is None: + return self.shards[0] + if num_shards is None: + num_shards = next(iter(self.shards.values())).num_shards + return self.shards[calculate_shard_id(guild_id, num_shards)] + + @property + def is_closed(self) -> bool: + """ + Returns + ------- + bool + Whether the bot is closed. + """ + return self.loop.is_running() + + def close(self): + """ + Ensure close of the http client. + Allow for script execution to continue. + """ if hasattr(self, "http"): create_task(self.http.close()) + self.loop.stop() + + def __del__(self): + if self.loop.is_running(): + self.loop.stop() + + if not self.loop.is_closed(): + self.close() + async def handle_middleware( self, payload: GatewayDispatch, @@ -679,9 +641,6 @@ async def event_handler(self, gateway: Gateway, payload: GatewayDispatch): Parameters ---------- - _ : - Socket param, but this isn't required for this handler. So - it's just a filler parameter, doesn't matter what is passed. payload : :class:`~pincer.core.dispatch.GatewayDispatch` The payload sent from the Discord gateway, this contains the required data for the client to know what event it is and @@ -698,9 +657,6 @@ async def payload_event_handler( Parameters ---------- - _ : - Socket param, but this isn't required for this handler. So - it's just a filler parameter, doesn't matter what is passed. 
payload : :class:`~pincer.core.dispatch.GatewayDispatch` The payload sent from the Discord gateway, this contains the required data for the client to know what event it is and @@ -1132,5 +1088,125 @@ async def sticker_packs(self) -> AsyncIterator[StickerPack]: for pack in packs: yield StickerPack.from_dict(pack) + async def create_stage( + self, + channel_id: int, + topic: str, + privacy_level: Optional[PrivacyLevel] = None, + reason: Optional[str] = None, + ) -> StageInstance: + """|coro| + + Parameters + ---------- + channel_id : :class:`int` + The id of the Stage channel + topic : :class:`str` + The topic of the Stage instance (1-120 characters) + privacy_level : Optional[:class:`~pincer.objects.guild.stage.PrivacyLevel`] + The privacy level of the Stage instance (default :data:`GUILD_ONLY`) + reason : Optional[:class:`str`] + The reason for creating the Stage instance + + Returns + ------- + :class:`~pincer.objects.guild.stage.StageInstance` + The Stage instance created + """ + + data = { + "channel_id": channel_id, + "topic": topic, + "privacy_level": privacy_level, + } + + return await self.http.post( # type: ignore + "stage-instances", remove_none(data), headers={"reason": reason} + ) + + async def get_stage(self, _id: int) -> StageInstance: + """|coro| + Gets the stage instance associated with the Stage channel, if it exists + + Parameters + ---------- + _id : int + The ID of the stage to get + + Returns + ------- + :class:`~pincer.objects.guild.stage.StageInstance` + The stage instance + """ + return await StageInstance.from_id(self, _id) + + async def modify_stage( + self, + _id: int, + topic: Optional[str] = None, + privacy_level: Optional[PrivacyLevel] = None, + reason: Optional[str] = None, + ): + """|coro| + Updates fields of an existing Stage instance. + Requires the user to be a moderator of the Stage channel. 
+ + Parameters + ---------- + _id : int + The ID of the stage to modify + topic : Optional[:class:`str`] + The topic of the Stage instance (1-120 characters) + privacy_level : Optional[:class:`~pincer.objects.guild.stage.PrivacyLevel`] + The privacy level of the Stage instance + reason : Optional[:class:`str`] + The reason for the modification + """ + + await self.http.patch( + f"stage-instances/{_id}", + remove_none({"topic": topic, "privacy_level": privacy_level}), + headers={"reason": reason}, + ) + + async def delete_stage(self, _id: int, reason: Optional[str] = None): + """|coro| + Deletes the Stage instance. + Requires the user to be a moderator of the Stage channel. + + Parameters + ---------- + _id : int + The ID of the stage to delete + reason : Optional[:class:`str`] + The reason for the deletion + """ + await self.http.delete( + f"stage-instances/{_id}", headers={"reason": reason} + ) + + async def crosspost_message(self, channel_id: int, message_id: int) -> UserMessage: + """|coro| + Crosspost a message in a News Channel to following channels. + + This endpoint requires the ``SEND_MESSAGES`` permission, + if the current user sent the message, or additionally the + ``MANAGE_MESSAGES`` permission, for all other messages, + to be present for the current user. + + Parameters + ---------- + channel_id : int + ID of the news channel that the message is in. + message_id : int + ID of the message to crosspost. + + Returns + ------- + :class:`~pincer.objects.message.UserMessage` + The crossposted message + """ + + return await self._http.post(f"channels/{channel_id}/{message_id}/crosspost") Bot = Client diff --git a/pincer/cog.py b/pincer/cog.py new file mode 100644 index 00000000..aee3063b --- /dev/null +++ b/pincer/cog.py @@ -0,0 +1,158 @@ +# Copyright Pincer 2021-Present +# Full MIT License can be found in `LICENSE` at the project root. 
+
+from __future__ import annotations
+from asyncio import ensure_future
+
+from importlib import reload, import_module
+from inspect import isclass
+from types import ModuleType
+from typing import TYPE_CHECKING, List
+
+from .commands.chat_command_handler import ChatCommandHandler
+from .commands.interactable import Interactable
+from .exceptions import CogAlreadyExists
+
+if TYPE_CHECKING:
+    from typing import Type
+    from .client import Client
+
+
+class CogManager:
+    """
+    A mixin that loads, lists and reloads cogs.
+    """
+
+    def load_cog(self, cog: Type[Cog]):
+        """Instantiate a :class:`~pincer.cog.Cog` subclass and register it.
+
+        The cog class is called with this manager (normally the client) as
+        its only argument, and the resulting instance is appended to
+        ``ChatCommandHandler.managers``.
+
+        :Example usage:
+
+        run.py
+
+        .. code-block:: python3
+
+            from pincer import Client
+            from cogs.say import SayCommand
+
+            class MyClient(Client):
+                def __init__(self, *args, **kwargs):
+                    self.load_cog(SayCommand)
+                    super().__init__(*args, **kwargs)
+
+        cogs/say.py
+
+        .. code-block:: python3
+
+            from pincer import Cog, command
+
+            class SayCommand(Cog):
+                @command()
+                async def say(self, message: str) -> str:
+                    return message
+
+        Parameters
+        ----------
+        cog : Type[:class:`~pincer.cog.Cog`]
+            The cog to load.
+
+        Raises
+        ------
+        CogAlreadyExists
+            An instance of this exact cog class is already loaded.
+        """
+        # ``managers`` stores cog *instances* (plus the client), so testing
+        # the class itself for membership cannot detect a duplicate; compare
+        # against each registered instance's class instead.
+        if any(type(manager) is cog for manager in ChatCommandHandler.managers):
+            raise CogAlreadyExists(
+                f"Cog `{cog}` is trying to be loaded but already exists."
+            )
+
+        cog_manager = cog(self)
+
+        ChatCommandHandler.managers.append(cog_manager)
+
+    def load_cogs(self, *cogs: Type[Cog]):
+        """
+        Loads a list of cogs
+
+        Parameters
+        ----------
+        \\*cogs : Type[:class:`~pincer.cog.Cog`]
+            A list of cogs to load.
+        """
+        for cog in cogs:
+            self.load_cog(cog)
+
+    def load_module(self, module: ModuleType):
+        """Loads the cogs from a module recursively.
+
+        Every :class:`~pincer.cog.Cog` subclass found in the namespace of
+        ``module`` (or of one of its submodules) is loaded exactly once.
+        Only submodules, i.e. modules whose name starts with
+        ``module.__name__ + "."``, are searched recursively. Other modules
+        the namespace references (``os``, ``pincer`` ...) are skipped:
+        following them can recurse forever through import cycles
+        (``os`` -> ``os.path`` -> ``os``) and load unrelated cogs.
+
+        Parameters
+        ----------
+        module : :class:`~types.ModuleType`
+            The module to load.
+        """
+        prefix = f"{module.__name__}."
+        visited = set()
+        seen_cogs = set()
+
+        def visit(current: ModuleType):
+            visited.add(current.__name__)
+            for item in vars(current).values():
+                if isinstance(item, ModuleType):
+                    if (
+                        item.__name__.startswith(prefix)
+                        and item.__name__ not in visited
+                    ):
+                        visit(item)
+                elif (
+                    item is not Cog
+                    and isclass(item)
+                    and issubclass(item, Cog)
+                    and item not in seen_cogs
+                ):
+                    # A package may re-export a cog that also lives in one
+                    # of its submodules; load each class only once.
+                    seen_cogs.add(item)
+                    self.load_cog(item)
+
+        visit(module)
+
+    def reload_cogs(self):
+        """Reloads all of the loaded cogs.
+
+        Each loaded cog is unassigned, every module that defines one of them
+        is re-executed with :func:`importlib.reload`, and the freshly defined
+        classes are loaded again. Commands are then re-initialized through
+        :class:`~pincer.commands.ChatCommandHandler`.
+        """
+        # Snapshot the cogs up front: ``unassign`` may change what
+        # ``self.cogs`` returns, and the same list is needed after reloading.
+        loaded_cogs = self.cogs
+        modules = {}
+
+        for cog in loaded_cogs:
+            cog.unassign()
+
+            module_name = type(cog).__module__
+            if module_name not in modules:
+                modules[module_name] = import_module(module_name)
+
+        for module_name, module in modules.items():
+            modules[module_name] = reload(module)
+
+        for cog in loaded_cogs:
+            # Look the new class up in the module that defined the old one.
+            cog_cls = getattr(
+                modules[type(cog).__module__], type(cog).__name__, None
+            )
+            if cog_cls is not None:
+                self.load_cog(cog_cls)
+
+        ChatCommandHandler.has_been_initialized = False
+        ensure_future(ChatCommandHandler(self).initialize())
+
+    @property
+    def cogs(self) -> List[Cog]:
+        """Get a list of all loaded cogs.
+
+        Returns
+        -------
+        List[:class:`~pincer.cog.Cog`]
+            The cog instances registered in ``ChatCommandHandler.managers``.
+        """
+        return [
+            manager for manager in ChatCommandHandler.managers
+            if isinstance(manager, Cog)
+        ]
+
+
+class Cog(Interactable):
+    """A cog object
+    This is an object that can register commands and message components that isn't a
+    client. It can also be reloaded at runtime so commands can be changed without
+    restarting the bot.
+    """
+
+    def __init__(self, client: Client) -> None:
+        self.client = client
+
+        super().__init__()
+
+    @classmethod
+    def name(cls) -> str:
+        """
+        Returns a unique name for this cog.
+
+        Returns
+        -------
+        str
+            A unique name for this cog.
+        """
+        return f"{cls.__module__}.{cls.__name__}"
diff --git a/pincer/commands/__init__.py b/pincer/commands/__init__.py
index a8ea6d78..ea8a15a5 100644
--- a/pincer/commands/__init__.py
+++ b/pincer/commands/__init__.py
@@ -1,7 +1,8 @@
 # Copyright Pincer 2021-Present
 # Full MIT License can be found in `LICENSE` at the project root.
-from .commands import command, user_command, message_command, ChatCommandHandler +from .commands import command, user_command, message_command +from .chat_command_handler import ChatCommandHandler from .arg_types import ( CommandArg, Description, @@ -17,11 +18,13 @@ component, button, select_menu, LinkButton ) from .groups import Group, Subgroup +from .interactable import Interactable, INTERACTION_REGISTERS __all__ = ( "ActionRow", "Button", "ButtonStyle", "ChannelTypes", "ChatCommandHandler", "Choice", "Choices", "CommandArg", - "ComponentHandler", "Description", "Group", "LinkButton", "MaxValue", - "MinValue", "Modifier", "SelectMenu", "SelectOption", "Subgroup", "button", - "command", "component", "message_command", "select_menu", "user_command" + "ComponentHandler", "Description", "Group", "INTERACTION_REGISTERS", + "Interactable", "LinkButton", "MaxValue", "MinValue", "Modifier", + "SelectMenu", "SelectOption", "Subgroup", "button", "command", "component", + "message_command", "select_menu", "user_command" ) diff --git a/pincer/commands/arg_types.py b/pincer/commands/arg_types.py index 6822c59d..db1e5b1b 100644 --- a/pincer/commands/arg_types.py +++ b/pincer/commands/arg_types.py @@ -1,11 +1,14 @@ # Copyright Pincer 2021-Present # Full MIT License can be found in `LICENSE` at the project root. +import logging from typing import Any, List, Tuple, Union, T from ..utils.types import MISSING from ..objects.app.command import AppCommandOptionChoice +_log = logging.getLogger(__name__) + class _CommandTypeMeta(type): def __getitem__(cls, args: Union[Tuple, Any]): @@ -19,16 +22,22 @@ class CommandArg(metaclass=_CommandTypeMeta): """ Holds the parameters of an application command option + .. note:: + Deprecated. :class:`typing.Annotated` or :class:`typing_extensions.Annotated` + should be used instead. See + https://docs.pincer.dev/en/stable/interactions.html#arguments for more + information. + .. 
code-block:: python3 - CommandArg[ + Annotated[ # This is the type of command. # Supported types are str, int, bool, float, User, Channel, and Role int, # The modifiers to the command go here - Description["Pick a number 1-10"], - MinValue[1], - MaxValue[10] + Description("Pick a number 1-10"), + MinValue(1), + MaxValue(10) ] Parameters @@ -42,6 +51,12 @@ class CommandArg(metaclass=_CommandTypeMeta): def __init__(self, command_type, *args): self.command_type = command_type self.modifiers = args + _log.warning( + "CommandArg is deprecated and will be removed in future releases." + " `typing.Annotated`/`typing_extensions.Annotated.` should be used instead." + " See https://docs.pincer.dev/en/stable/interactions.html#arguments for" + " more information." + ) def get_arg(self, arg_type: T) -> T: for arg in self.modifiers: @@ -55,6 +70,21 @@ class Modifier(metaclass=_CommandTypeMeta): """ Modifies a CommandArg by being added to :class:`~pincer.commands.arg_types.CommandArg`'s args. + + Modifiers go inside an :class:`typing.Annotated` type hint. + + .. code-block:: python3 + + Annotated[ + # This is the type of command. + # Supported types are str, int, bool, float, User, Channel, and Role + int, + # The modifiers to the command go here + Description("Pick a number 1-10"), + MinValue(1), + MaxValue(10) + ] + """ @@ -65,9 +95,9 @@ class Description(Modifier): .. code-block:: python3 # Creates an int argument with the description "example description" - CommandArg[ + Annotated[ int, - Description["example description"] + Description("example description") ] Parameters @@ -89,10 +119,10 @@ class Choice(Modifier): .. code-block:: python3 - Choices[ - Choice["First Number", 10], - Choice["Second Number", 20] - ] + Choices( + Choice("First Number", 10), + Choice("Second Number", 20) + ) Parameters ---------- @@ -113,13 +143,13 @@ class Choices(Modifier): .. 
code-block:: python3 - CommandArg[ + Annotated[ int, - Choices[ - Choice["First Number", 10], + Choices( + Choice("First Number", 10), 20, 50 - ] + ) ] Parameters @@ -153,14 +183,14 @@ class ChannelTypes(Modifier): .. code-block:: python3 - CommandArg[ + Annotated[ Channel, # The user will only be able to choice between GUILD_TEXT and GUILD_TEXT channels. - ChannelTypes[ + ChannelTypes( ChannelType.GUILD_TEXT, ChannelType.GUILD_VOICE - ] + ) ] Parameters @@ -182,10 +212,10 @@ class MaxValue(Modifier): .. code-block:: python3 - CommandArg[ + Annotated[ int, # The user can't pick a number above 10 - MaxValue[10] + MaxValue(10) ] Parameters @@ -207,10 +237,10 @@ class MinValue(Modifier): .. code-block:: python3 - CommandArg[ + Annotated[ int, # The user can't pick a number below 10 - MinValue[10] + MinValue(10) ] Parameters diff --git a/pincer/commands/chat_command_handler.py b/pincer/commands/chat_command_handler.py new file mode 100644 index 00000000..1772f2b0 --- /dev/null +++ b/pincer/commands/chat_command_handler.py @@ -0,0 +1,439 @@ +# Copyright Pincer 2021-Present +# Full MIT License can be found in `LICENSE` at the project root. 
+ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import logging +from asyncio import gather + +from ..utils.types import MISSING, Singleton + +from ..exceptions import ForbiddenError +from ..objects.guild.guild import Guild +from ..objects.app.command import AppCommand, AppCommandOption +from ..objects.app.command_types import AppCommandOptionType, AppCommandType + +if TYPE_CHECKING: + from typing import List, Dict, Optional, ValuesView, Union + from .interactable import Interactable + from ..client import Client + from ..utils.snowflake import Snowflake + from ..objects.app.command import InteractableStructure + +_log = logging.getLogger(__name__) + + +class ChatCommandHandler(metaclass=Singleton): + """Singleton containing methods used to handle various commands + + The register and built_register + ------------------------------- + I found the way Discord expects commands to be registered to be very different than + how you want to think about command registration. i.e. Discord wants nesting but we + don't want any nesting. Nesting makes it hard to think about commands and also will + increase lookup time. + The way this problem is avoided is by storing a version of the commands that we can + deal with as library developers and a version of the command that Discord thinks we + should provide. That is where the register and the built_register help simplify the + design of the library. + The register is simply where the "Pincer version" of commands gets saved to memory. + The built_register is where the version of commands that Discord requires is saved. + The register allows for O(1) lookups by storing commands in a Python dictionary. It + does cost some memory to save two copies in the current iteration of the system but + we should be able to drop the built_register in runtime if we want to. I don't feel + that lost maintainability from this is optimal. 
We can index by in O(1) by checking + the register but can still use the built_register if we need to do a nested lookup. + + Attributes + ---------- + client: :class:`Client` + The client object + managers: Dict[:class:`str`, :class:`~typing.Any`] + Dictionary of managers + register: Dict[:class:`str`, :class:`~pincer.objects.app.command.InteractableStructure`[:class:`~pincer.objects.app.command.AppCommand`]] + Dictionary of ``InteractableStructure`` + built_register: Dict[:class:`str`, :class:`~pincer.objects.app.command.AppCommand`]] + Dictionary of ``InteractableStructure`` where the commands are converted to + the format that Discord expects for sub commands and sub command groups. + """ # noqa: E501 + + has_been_initialized = False + managers: List[Interactable] = [] + register: Dict[str, InteractableStructure[AppCommand]] = {} + built_register: Dict[str, AppCommand] = {} + + # Endpoints: + __get = "/commands" + __delete = "/commands/{command.id}" + __update = "/commands/{command.id}" + __add = "/commands" + __add_guild = "/guilds/{command.guild_id}/commands" + __get_guild = "/guilds/{guild_id}/commands" + __update_guild = "/guilds/{command.guild_id}/commands/{command.id}" + __delete_guild = "/guilds/{command.guild_id}/commands/{command.id}" + + def __init__(self, client: Client): + self.client = client + self._api_commands: List[AppCommand] = [] + _log.debug( + "%i commands registered.", len(ChatCommandHandler.register) + ) + + self.__prefix = f"applications/{self.client.bot.id}" + + async def get_commands(self) -> List[AppCommand]: + """|coro| + + Get a list of app commands from Discord + + Returns + ------- + List[:class:`~pincer.objects.app.command.AppCommand`] + List of commands. 
+ """ + # TODO: Update if discord adds bulk get guild commands + guild_commands = await gather(*( + self.client.http.get( + self.__prefix + self.__get_guild.format( + guild_id=guild.id if isinstance(guild, Guild) else guild + ) + ) for guild in self.client.guilds + )) + return list( + map( + AppCommand.from_dict, + await self.client.http.get(self.__prefix + self.__get) + + [cmd for guild in guild_commands for cmd in guild], + ) + ) + + async def remove_command(self, cmd: AppCommand): + """|coro| + + Remove a specific command + + Parameters + ---------- + cmd : :class:`~pincer.objects.app.command.AppCommand` + What command to delete + """ + # TODO: Update if discord adds bulk delete commands + if cmd.guild_id: + _log.info( + "Removing command `%s` with guild id %d from Discord", + cmd.name, + cmd.guild_id, + ) + else: + _log.info("Removing global command `%s` from Discord", cmd.name) + + remove_endpoint = self.__delete_guild if cmd.guild_id else self.__delete + + await self.client.http.delete( + self.__prefix + remove_endpoint.format(command=cmd) + ) + + async def add_command(self, cmd: AppCommand): + """|coro| + + Add an app command + + Parameters + ---------- + cmd : :class:`~pincer.objects.app.command.AppCommand` + Command to add + """ + _log.info("Updated or registered command `%s` to Discord", cmd.name) + + add_endpoint = self.__add + + if cmd.guild_id: + add_endpoint = self.__add_guild.format(command=cmd) + + await self.client.http.post( + self.__prefix + add_endpoint, data=cmd.to_dict() + ) + + async def add_commands(self, commands: List[AppCommand]): + """|coro| + + Add a list of app commands + + Parameters + ---------- + commands : List[:class:`~pincer.objects.app.command.AppCommand`] + List of command objects to add + """ + await gather(*map(self.add_command, commands)) + + @staticmethod + def __build_local_commands(): + """Builds the commands into the format that Discord expects. See class info + for the reasoning. 
+ """ + + # Reset the built register + ChatCommandHandler.built_register = {} + + for cmd in ChatCommandHandler.register.values(): + + if cmd.sub_group: + # If a command has a sub_group, it must be nested 2 levels deep. + # + # command + # subcommand-group + # subcommand + # + # The children of the subcommand-group object are being set to include + # `cmd` If that subcommand-group object does not exist, it will be + # created here. The same goes for the top-level command. + # + # First make sure the command exists. This command will hold the + # subcommand-group for `cmd`. + + # `key` represents the hash value for the top-level command that will + # hold the subcommand. + key = _hash_app_command_params( + cmd.group.name, + cmd.metadata.guild_id, + AppCommandType.CHAT_INPUT, + None, + None, + ) + + if key not in ChatCommandHandler.built_register: + ChatCommandHandler.built_register[key] = AppCommand( + name=cmd.group.name, + description=cmd.group.description, + type=AppCommandType.CHAT_INPUT, + guild_id=cmd.metadata.guild_id, + options=[] + ) + + # The top-level command now exists. A subcommand group now if placed + # inside the top-level command. This subcommand group will hold `cmd`. 
+ + children = ChatCommandHandler.built_register[key].options + + sub_command_group = AppCommandOption( + name=cmd.sub_group.name, + description=cmd.sub_group.description, + type=AppCommandOptionType.SUB_COMMAND_GROUP, + options=[] + ) + + # This for-else makes sure that sub_command_group will hold a reference + # to the subcommand group that we want to modify to hold `cmd` + + for cmd_in_children in children: + if ( + cmd_in_children.name == sub_command_group.name + and cmd_in_children.description == sub_command_group.description + and cmd_in_children.type == sub_command_group.type + ): + sub_command_group = cmd_in_children + break + else: + children.append(sub_command_group) + + sub_command_group.options.append(AppCommandOption( + name=cmd.metadata.name, + description=cmd.metadata.description, + type=AppCommandOptionType.SUB_COMMAND, + options=cmd.metadata.options, + )) + + continue + + if cmd.group: + # Any command at this point will only have one level of nesting. + # + # Command + # subcommand + # + # A subcommand object is what is being generated here. If there is no + # top level command, it will be created here. + + # `key` represents the hash value for the top-level command that will + # hold the subcommand. + + key = _hash_app_command_params( + cmd.group.name, + cmd.metadata.guild_id, + AppCommandOptionType.SUB_COMMAND, + None, + None + ) + + if key not in ChatCommandHandler.built_register: + ChatCommandHandler.built_register[key] = AppCommand( + name=cmd.group.name, + description=cmd.group.description, + type=AppCommandOptionType.SUB_COMMAND, + guild_id=cmd.metadata.guild_id, + options=[] + ) + + # No checking has to be done before appending `cmd` since it is the + # lowest level. 
+ ChatCommandHandler.built_register[key].options.append( + AppCommandOption( + name=cmd.metadata.name, + description=cmd.metadata.description, + type=AppCommandType.CHAT_INPUT, + options=cmd.metadata.options + ) + ) + + continue + + # All single-level commands are registered here. + ChatCommandHandler.built_register[ + _hash_interactable_structure(cmd) + ] = cmd.metadata + + @staticmethod + def get_local_registered_commands() -> ValuesView[AppCommand]: + return ChatCommandHandler.built_register.values() + + async def __get_existing_commands(self): + """|coro| + + Get AppCommand objects for all commands registered to discord. + """ + try: + self._api_commands = await self.get_commands() + except ForbiddenError: + logging.error("Cannot retrieve slash commands, skipping...") + return + + async def __remove_unused_commands(self): + """|coro| + + Remove commands that are registered by discord but not in use + by the current client + """ + local_registered_commands = self.get_local_registered_commands() + + def should_be_removed(target: AppCommand) -> bool: + # Commands have endpoints based on their `name` amd `guild_id`. Other + # parameters can be updated instead of deleting and re-registering the + # command. + return all( + target.name != reg_cmd.name + and target.guild_id != reg_cmd.guild_id + for reg_cmd in local_registered_commands + ) + + # NOTE: Cannot be generator since it can't be consumed due to lines 743-745 + to_remove = [*filter(should_be_removed, self._api_commands)] + + await gather( + *(self.remove_command(cmd) for cmd in to_remove) + ) + + self._api_commands = [ + cmd for cmd in self._api_commands + if cmd not in to_remove + ] + + async def __add_commands(self): + """|coro| + Add all new commands which have been registered by the decorator to Discord. + + .. code-block:: + + Because commands have unique names within a type and scope, we treat POST + requests for new commands as upserts. 
That means making a new command with + an already-used name for your application will update the existing command. + ``_ + + Therefore, we don't need to use a separate loop for updating and adding + commands. + """ + for command in self.get_local_registered_commands(): + if command not in self._api_commands: + await self.add_command(command) + + async def initialize(self): + """|coro| + + Call methods of this class to refresh all app commands + """ + if ChatCommandHandler.has_been_initialized: + # Only first shard should be initialized. + return + + ChatCommandHandler.has_been_initialized = True + + self.__build_local_commands() + await self.__get_existing_commands() + await self.__remove_unused_commands() + await self.__add_commands() + + +def _hash_interactable_structure(interactable: InteractableStructure[AppCommand]): + return _hash_app_command( + interactable.metadata, + interactable.group, + interactable.sub_group + ) + + +def _hash_app_command( + command: AppCommand, + group: Optional[str], + sub_group: Optional[str] +) -> int: + """ + See :func:`~pincer.commands.commands._hash_app_command_params` for information. + """ + return _hash_app_command_params( + command.name, + command.guild_id, + command.type, + group, + sub_group + ) + + +def _hash_app_command_params( + name: str, + guild_id: Union[Snowflake, None, MISSING], + app_command_type: AppCommandType, + group: Optional[str], + sub_group: Optional[str] +) -> int: + """ + The group layout in Pincer is very different from what discord has on their docs. + You can think of the Pincer group layout like this: + + name: The name of the function that is being called. + + group: The :class:`~pincer.commands.groups.Group` object that this function is + using. + sub_option: The :class:`~pincer.commands.groups.Subgroup` object that this + functions is using. 
+ + Abstracting away this part of the Discord API allows for a much cleaner + transformation between what users want to input and what commands Discord + expects. + + Parameters + ---------- + name : str + The name of the function for the command + guild_id : Union[:class:`~pincer.utils.snowflake.Snowflake`, None, MISSING] + The ID of a guild, None, or MISSING. + app_command_type : :class:`~pincer.objects.app.command_types.AppCommandType` + The app command type of the command. NOT THE OPTION TYPE. + group : str + The highest level of organization the command is it. This should always be the + name of the base command. :data:`None` or :data:`MISSING` if not there. + sub_option : str + The name of the group that holds the lowest level of options. :data:`None` or + :data:`MISSING` if not there. + """ + return hash((name, guild_id, app_command_type, group, sub_group)) diff --git a/pincer/commands/commands.py b/pincer/commands/commands.py index 739fc38d..a36c8931 100644 --- a/pincer/commands/commands.py +++ b/pincer/commands/commands.py @@ -5,13 +5,15 @@ import logging import re -from asyncio import iscoroutinefunction, gather +from asyncio import iscoroutinefunction from functools import partial from inspect import Signature, isasyncgenfunction, _empty -from typing import TYPE_CHECKING, Union, List, ValuesView - +from typing import TYPE_CHECKING, Any, Callable, TypeVar, Union, List from . 
import __package__ +from .chat_command_handler import ( + ChatCommandHandler, _hash_app_command_params +) from ..commands.arg_types import ( ChannelTypes, CommandArg, @@ -22,6 +24,7 @@ ) from ..commands.groups import Group, Subgroup from ..utils.snowflake import Snowflake +from ..utils.types import APINullable, MISSING from ..exceptions import ( CommandIsNotCoroutine, CommandAlreadyRegistered, @@ -30,7 +33,6 @@ CommandDescriptionTooLong, InvalidCommandGuild, InvalidCommandName, - ForbiddenError, ) from ..objects import ( ThrottleScope, @@ -38,23 +40,20 @@ Role, User, Channel, - Guild, Mentionable, MessageContext, ) from ..objects.app import ( AppCommandOptionType, AppCommandOption, - ClientCommandStructure, + InteractableStructure, AppCommandType, ) -from ..utils import get_index, should_pass_ctx +from ..utils import should_pass_ctx from ..utils.signature import get_signature_and_params -from ..utils.types import MISSING -from ..utils.types import Singleton if TYPE_CHECKING: - from typing import Any, Optional, Dict + from typing import Optional REGULAR_COMMAND_NAME_REGEX = re.compile(r"[\w\- ]{1,32}$") CHAT_INPUT_COMMAND_NAME_REGEX = re.compile(r"^[a-z0-9_-]{1,32}$") @@ -76,6 +75,8 @@ if TYPE_CHECKING: from ..client import Client +T = TypeVar("T") + def command( func=None, @@ -114,13 +115,13 @@ async def test_command( amount: int, name: CommandArg[ str, - Description["Do something cool"], - Choices[Choice["first value", 1], 5] + Description("Do something cool"), + Choices(Choice("first value", 1), 5) ], optional_int: CommandArg[ int, - MinValue[10], - MaxValue[100], + MinValue(10), + MaxValue(100), ] = 50 ): return Message( @@ -136,6 +137,8 @@ async def test_command( :class:`~pincer.objects.app.interaction_flags.InteractionFlags`, :class:`~pincer.commands.arg_types.Choices`, :class:`~pincer.commands.arg_types.Choice`, + :class:`typing_extensions.Annotated` (Python 3.8), + :class:`typing.Annotated` (Python 3.9+), :class:`~pincer.commands.arg_types.CommandArg`, 
:class:`~pincer.commands.arg_types.Description`, :class:`~pincer.commands.arg_types.MinValue`, @@ -228,7 +231,16 @@ async def test_command( if annotation == MessageContext and idx == 1: return - if type(annotation) is not CommandArg: + argument_type = None + if type(annotation) is CommandArg: + argument_type = annotation.command_type + # isinstance and type don't work for Annotated. This is the best way 💀 + elif hasattr(annotation, "__metadata__"): + # typing.get_origin doesn't work in 3.9+ for some reason. Maybe they forgor + # to implement it. + argument_type = annotation.__origin__ + + if not argument_type: if annotation in _options_type_link: options.append( AppCommandOption( @@ -242,16 +254,27 @@ async def test_command( # TODO: Write better exception raise InvalidArgumentAnnotation( - "Type must be CommandArg or other valid type" + "Type must be Annotated or other valid type" ) - command_type = _options_type_link[annotation.command_type] + command_type = _options_type_link[argument_type] + + def get_arg(t: T) -> APINullable[T]: + if type(annotation) is CommandArg: + return annotation.get_arg(t) + elif hasattr(annotation, "__metadata__"): + for obj in annotation.__metadata__: + if isinstance(obj, t): + return obj.get_payload() + return MISSING + argument_description = ( - annotation.get_arg(Description) or "Description not set" + get_arg(Description) or "Description not set" ) - choices = annotation.get_arg(Choices) - if choices is not MISSING and annotation.command_type not in { + choices = get_arg(Choices) + + if choices is not MISSING and argument_type not in { int, float, str, @@ -263,31 +286,31 @@ async def test_command( for choice in choices: if ( isinstance(choice.value, int) - and annotation.command_type is float + and argument_type is float ): continue - if not isinstance(choice.value, annotation.command_type): + if not isinstance(choice.value, argument_type): raise InvalidArgumentAnnotation( "Choice value must match the command type" ) - channel_types 
= annotation.get_arg(ChannelTypes) + channel_types = get_arg(ChannelTypes) if ( channel_types is not MISSING - and annotation.command_type is not Channel + and argument_type is not Channel ): raise InvalidArgumentAnnotation( "ChannelTypes are only available for Channels" ) - max_value = annotation.get_arg(MaxValue) - min_value = annotation.get_arg(MinValue) + max_value = get_arg(MaxValue) + min_value = get_arg(MinValue) for i, value in enumerate((min_value, max_value)): if ( value is not MISSING - and annotation.command_type is not int - and annotation.command_type is not float + and argument_type is not int + and argument_type is not float ): t = ("MinValue", "MaxValue") raise InvalidArgumentAnnotation( @@ -399,6 +422,17 @@ async def test_user_command( Not a valid argument type, Annotations must consist of name and value """ # noqa: E501 + if func is None: + return partial( + user_command, + name=name, + enable_default=enable_default, + guild=guild, + cooldown=cooldown, + cooldown_scale=cooldown_scale, + cooldown_scope=cooldown_scope, + ) + return register_command( func=func, app_command_type=AppCommandType.USER, @@ -478,6 +512,17 @@ async def test_message_command( Not a valid argument type, Annotations must consist of name and value """ # noqa: E501 + if func is None: + return partial( + message_command, + name=name, + enable_default=enable_default, + guild=guild, + cooldown=cooldown, + cooldown_scale=cooldown_scale, + cooldown_scope=cooldown_scope, + ) + return register_command( func=func, app_command_type=AppCommandType.MESSAGE, @@ -491,8 +536,7 @@ async def test_message_command( def register_command( - func=None, # Missing typehint? - *, + func: Callable[..., Any] = None, app_command_type: Optional[AppCommandType] = None, name: Optional[str] = None, description: Optional[str] = MISSING, @@ -504,20 +548,6 @@ def register_command( command_options=MISSING, # Missing typehint? 
parent: Optional[Union[Group, Subgroup]] = MISSING ): - if func is None: - return partial( - register_command, - name=name, - app_command_type=app_command_type, - description=description, - enable_default=enable_default, - guild=guild, - cooldown=cooldown, - cooldown_scale=cooldown_scale, - cooldown_scope=cooldown_scope, - parent=parent - ) - cmd = name or func.__name__ if not re.match(REGULAR_COMMAND_NAME_REGEX, cmd): @@ -564,16 +594,19 @@ def register_command( f"registered by `{reg.call.__name__}`." ) - ChatCommandHandler.register[ - _hash_app_command_params(cmd, guild_id, app_command_type, group, sub_group) - ] = ClientCommandStructure( + _log.info( + f"Registered command `{cmd}` to `{func.__name__}` locally." + ) + + interactable = InteractableStructure( call=func, cooldown=cooldown, cooldown_scale=cooldown_scale, cooldown_scope=cooldown_scope, + manager=None, group=group, sub_group=sub_group, - app=AppCommand( + metadata=AppCommand( name=cmd, description=description, type=app_command_type, @@ -583,431 +616,14 @@ def register_command( ), ) - _log.info(f"Registered command `{cmd}` to `{func.__name__}` locally.") - return func - - -class ChatCommandHandler(metaclass=Singleton): - """Singleton containing methods used to handle various commands - - The register and built_register - ------------------------------- - I found the way Discord expects commands to be registered to be very different than - how you want to think about command registration. i.e. Discord wants nesting but we - don't want any nesting. Nesting makes it hard to think about commands and also will - increase lookup time. - The way this problem is avoided is by storing a version of the commands that we can - deal with as library developers and a version of the command that Discord thinks we - should provide. That is where the register and the built_register help simplify the - design of the library. - The register is simply where the "Pincer version" of commands gets saved to memory. 
- The built_register is where the version of commands that Discord requires is saved. - The register allows for O(1) lookups by storing commands in a Python dictionary. It - does cost some memory to save two copies in the current iteration of the system but - we should be able to drop the built_register in runtime if we want to. I don't feel - that lost maintainability from this is optimal. We can index by in O(1) by checking - the register but can still use the built_register if we need to do a nested lookup. - - Attributes - ---------- - client: :class:`Client` - The client object - managers: Dict[:class:`str`, :class:`~typing.Any`] - Dictionary of managers - register: Dict[:class:`str`, :class:`~pincer.objects.app.command.ClientCommandStructure`] - Dictionary of ``ClientCommandStructure`` - built_register: Dict[:class:`str`, :class:`~pincer.objects.app.command.ClientCommandStructure`] - Dictionary of ``ClientCommandStructure`` where the commands are converted to - the format that Discord expects for sub commands and sub command groups. 
- """ # noqa: E501 - - has_been_initialized = False - managers: Dict[str, Any] = {} - register: Dict[str, ClientCommandStructure] = {} - built_register: Dict[str, AppCommand] = {} - - # Endpoints: - __get = "/commands" - __delete = "/commands/{command.id}" - __update = "/commands/{command.id}" - __add = "/commands" - __add_guild = "/guilds/{command.guild_id}/commands" - __get_guild = "/guilds/{guild_id}/commands" - __update_guild = "/guilds/{command.guild_id}/commands/{command.id}" - __delete_guild = "/guilds/{command.guild_id}/commands/{command.id}" - - def __init__(self, client: Client): - self.client = client - self._api_commands: List[AppCommand] = [] - logging.debug( - "%i commands registered.", len(ChatCommandHandler.register.items()) - ) - self.client.throttler.throttle = dict( - map( - lambda cmd: (cmd.call, {}), ChatCommandHandler.register.values() - ) - ) - - self.__prefix = f"applications/{self.client.bot.id}" - - async def get_commands(self) -> List[AppCommand]: - """|coro| - - Get a list of app commands from Discord - - Returns - ------- - List[:class:`~pincer.objects.app.command.AppCommand`] - List of commands. 
- """ - # TODO: Update if discord adds bulk get guild commands - guild_commands = await gather( - *map( - lambda guild: self.client.http.get( - self.__prefix - + self.__get_guild.format( - guild_id=guild.id if isinstance(guild, Guild) else guild - ) - ), - self.client.guilds, - ) - ) - return list( - map( - AppCommand.from_dict, - await self.client.http.get(self.__prefix + self.__get) - + [cmd for guild in guild_commands for cmd in guild], - ) - ) - - async def remove_command(self, cmd: AppCommand): - """|coro| - - Remove a specific command - - Parameters - ---------- - cmd : :class:`~pincer.objects.app.command.AppCommand` - What command to delete - """ - # TODO: Update if discord adds bulk delete commands - if cmd.guild_id: - _log.info( - "Removing command `%s` with guild id %d from Discord", - cmd.name, - cmd.guild_id, - ) - else: - _log.info("Removing global command `%s` from Discord", cmd.name) - - remove_endpoint = self.__delete_guild if cmd.guild_id else self.__delete - - await self.client.http.delete( - self.__prefix + remove_endpoint.format(command=cmd) - ) - - async def add_command(self, cmd: AppCommand): - """|coro| - - Add an app command - - Parameters - ---------- - cmd : :class:`~pincer.objects.app.command.AppCommand` - Command to add - """ - _log.info("Updated or registered command `%s` to Discord", cmd.name) - - add_endpoint = self.__add - - if cmd.guild_id: - add_endpoint = self.__add_guild.format(command=cmd) - - await self.client.http.post( - self.__prefix + add_endpoint, data=cmd.to_dict() - ) - - async def add_commands(self, commands: List[AppCommand]): - """|coro| - - Add a list of app commands - - Parameters - ---------- - commands : List[:class:`~pincer.objects.app.command.AppCommand`] - List of command objects to add - """ - await gather(*map(lambda cmd: self.add_command(cmd), commands)) - - def __build_local_commands(self): - """Builds the commands into the format that Discord expects. See class info - for the reasoning. 
- """ - for cmd in ChatCommandHandler.register.values(): - - if cmd.sub_group: - # If a command has a sub_group, it must be nested to levels deep. - # - # command - # subcommand-group - # subcommand - # - # The children of the subcommand-group object are being set to include - # `cmd` If that subcommand-group object does not exist, it will be - # created here. The same goes for the top-level command. - # - # First make sure the command exists. This command will hold the - # subcommand-group for `cmd`. - - # `key` represents the hash value for the top-level command that will - # hold the subcommand. - key = _hash_app_command_params( - cmd.group.name, - cmd.app.guild_id, - AppCommandType.CHAT_INPUT, - None, - None, - ) - - if key not in ChatCommandHandler.built_register: - ChatCommandHandler.built_register[key] = AppCommand( - name=cmd.group.name, - description=cmd.group.description, - type=AppCommandType.CHAT_INPUT, - guild_id=cmd.app.guild_id, - options=[] - ) - - # The top-level command now exists. A subcommand group now if placed - # inside the top-level command. This subcommand group will hold `cmd`. 
- - children = ChatCommandHandler.built_register[key].options - - sub_command_group = AppCommandOption( - name=cmd.sub_group.name, - description=cmd.sub_group.description, - type=AppCommandOptionType.SUB_COMMAND_GROUP, - options=[] - ) - - # This for-else makes sure that sub_command_group will hold a reference - # to the subcommand group that we want to modify to hold `cmd` - - for cmd_in_children in children: - if ( - cmd_in_children.name == sub_command_group.name - and cmd_in_children.description == sub_command_group.description - and cmd_in_children.type == sub_command_group.type - ): - sub_command_group = cmd_in_children - break - else: - children.append(sub_command_group) - - sub_command_group.options.append(AppCommandOption( - name=cmd.app.name, - description=cmd.app.description, - type=AppCommandOptionType.SUB_COMMAND, - options=cmd.app.options, - )) - - continue - - if cmd.group: - # Any command at this point will only have one level of nesting. - # - # Command - # subcommand - # - # A subcommand object is what is being generated here. If there is no - # top level command, it will be created here. - - # `key` represents the hash value for the top-level command that will - # hold the subcommand. - - key = _hash_app_command_params( - cmd.group.name, - cmd.app.guild_id, - AppCommandOptionType.SUB_COMMAND, - None, - None - ) - - if key not in ChatCommandHandler.built_register: - ChatCommandHandler.built_register[key] = AppCommand( - name=cmd.group.name, - description=cmd.group.description, - type=AppCommandOptionType.SUB_COMMAND, - guild_id=cmd.app.guild_id, - options=[] - ) - - # No checking has to be done before appending `cmd` since it is the - # lowest level. - ChatCommandHandler.built_register[key].options.append( - AppCommandOption( - name=cmd.app.name, - description=cmd.app.description, - type=AppCommandType.CHAT_INPUT, - options=cmd.app.options - ) - ) - - continue - - # All single-level commands are registered here. 
- ChatCommandHandler.built_register[ - _hash_app_command(cmd.app, cmd.group, cmd.sub_group) - ] = cmd.app - - def get_local_registered_commands(self) -> ValuesView[AppCommand]: - return ChatCommandHandler.built_register.values() - - async def __get_existing_commands(self): - """|coro| - - Get AppCommand objects for all commands registered to discord. - """ - try: - self._api_commands = await self.get_commands() - - except ForbiddenError: - logging.error("Cannot retrieve slash commands, skipping...") - return - - async def __remove_unused_commands(self): - """|coro| - - Remove commands that are registered by discord but not in use - by the current client - """ - local_registered_commands = self.get_local_registered_commands() - - def should_be_removed(target: AppCommand) -> bool: - for reg_cmd in local_registered_commands: - # Commands have endpoints based on their `name` amd `guild_id`. Other - # parameters can be updated instead of deleting and re-registering the - # command. - if ( - target.name == reg_cmd.name - and target.guild_id == reg_cmd.guild_id - ): - return False - return True - - # NOTE: Cannot be generator since it can't be consumed due to lines 743-745 - to_remove = [*filter(should_be_removed, self._api_commands)] - - await gather( - *map( - lambda cmd: self.remove_command(cmd), - to_remove, - ) - ) - - self._api_commands = list( - filter(lambda cmd: cmd not in to_remove, self._api_commands) - ) - - async def __add_commands(self): - """|coro| - - Add all new commands which have been registered by the decorator to Discord. - - .. code-block:: - - Because commands have unique names within a type and scope, we treat POST - requests for new commands as upserts. That means making a new command with - an already-used name for your application will update the existing command. - ``_ - - Therefore, we don't need to use a separate loop for updating and adding - commands. 
- """ - local_registered_commands = self.get_local_registered_commands() - - def should_be_updated_or_uploaded(target): - for command in self._api_commands: - if target == command: - return False - return True - - changed_commands = filter( - should_be_updated_or_uploaded, local_registered_commands + ChatCommandHandler.register[ + _hash_app_command_params( + cmd, + guild_id, + app_command_type, + group, + sub_group ) + ] = interactable - for command in changed_commands: - await self.add_command(command) - - async def initialize(self): - """|coro| - - Call methods of this class to refresh all app commands - """ - if ChatCommandHandler.has_been_initialized: - # Only first shard should be initialized. - return - - ChatCommandHandler.has_been_initialized = True - - self.__build_local_commands() - await self.__get_existing_commands() - await self.__remove_unused_commands() - await self.__add_commands() - - -def _hash_app_command( - command: AppCommand, - group: Optional[str], - sub_group: Optional[str] -) -> int: - """ - See :func:`~pincer.commands.commands._hash_app_command_params` for information. - """ - return _hash_app_command_params( - command.name, - command.guild_id, - command.type, - group, - sub_group - ) - - -def _hash_app_command_params( - name: str, - guild_id: Union[Snowflake, None, MISSING], - app_command_type: AppCommandType, - group: Optional[str], - sub_group: Optional[str] -) -> int: - """ - The group layout in Pincer is very different than what discord has on their docs. - You can think of the Pincer group layout like this: - - name: The name of the function that is being called. - - group: The :class:`~pincer.commands.groups.Group` object that this function is - using. - sub_option: The :class:`~pincer.commands.groups.Subgroup` object that this - functions is using. - - Abstracting away this part of the Discord API allows for a much cleaner - transformation between what users want to input and what commands Discord - expects. 
- - Parameters - ---------- - name : str - The name of the function for the command - guild_id : Union[:class:`~pincer.utils.snowflake.Snowflake`, None, MISSING] - The ID of a guild, None, or MISSING. - app_command_type : :class:`~pincer.objects.app.command_types.AppCommandType` - The app command type of the command. NOT THE OPTION TYPE. - group : str - The highest level of organization the command is it. This should always be the - name of the base command. :data:`None` or :data:`MISSING` if not there. - sub_option : str - The name of the group that holds the lowest level of options. :data:`None` or - :data:`MISSING` if not there. - """ - return hash((name, guild_id, app_command_type, group, sub_group)) + return interactable diff --git a/pincer/commands/components/action_row.py b/pincer/commands/components/action_row.py index c44bade0..a3d26783 100644 --- a/pincer/commands/components/action_row.py +++ b/pincer/commands/components/action_row.py @@ -5,7 +5,8 @@ from typing import TYPE_CHECKING -from ...objects.message.component import MessageComponent +from ._component import _Component +from ...objects.app.command import InteractableStructure from ...utils.api_object import APIObject if TYPE_CHECKING: @@ -23,11 +24,13 @@ class ActionRow(APIObject): :class:`~pincer.objects.message.select_menu.SelectMenu` """ - def __init__(self, *components: MessageComponent): + def __init__(self, *components: InteractableStructure[_Component]): self.components = components def to_dict(self) -> Dict: return { "type": 1, - "components": [component.to_dict() for component in self.components] + "components": [ + component.metadata.to_dict() for component in self.components + ] } diff --git a/pincer/commands/components/button.py b/pincer/commands/components/button.py index 28458c89..48bc9725 100644 --- a/pincer/commands/components/button.py +++ b/pincer/commands/components/button.py @@ -5,7 +5,7 @@ from dataclasses import dataclass from enum import IntEnum -from typing import 
TYPE_CHECKING, Callable, Optional +from typing import TYPE_CHECKING from ._component import _Component from ...utils.types import MISSING @@ -75,14 +75,10 @@ class Button(_Component): disabled: APINullable[bool] = False type: int = 2 - _func: Optional[Callable] = None def __post_init__(self): self.type = 2 - def __call__(self, *args, **kwargs): - return self._func(*args, **kwargs) - @dataclass(repr=False) class LinkButton(_Component): diff --git a/pincer/commands/components/component_handler.py b/pincer/commands/components/component_handler.py index 982ae341..21613d3a 100644 --- a/pincer/commands/components/component_handler.py +++ b/pincer/commands/components/component_handler.py @@ -1,9 +1,12 @@ # Copyright Pincer 2021-Present # Full MIT License can be found in `LICENSE` at the project root. -from typing import Callable, Dict +from typing import Dict + +from ._component import _Component from ...utils.types import Singleton +from ...objects.app.command import InteractableStructure class ComponentHandler(metaclass=Singleton): @@ -15,7 +18,4 @@ class ComponentHandler(metaclass=Singleton): Dictionary of registered buttons. 
""" - register: Dict[str, Callable] = {} - - def register_id(self, _id: str, func: Callable): - self.register[_id] = func + register: Dict[str, InteractableStructure[_Component]] = {} diff --git a/pincer/commands/components/decorators.py b/pincer/commands/components/decorators.py index 5ddd868c..e9bd08cc 100644 --- a/pincer/commands/components/decorators.py +++ b/pincer/commands/components/decorators.py @@ -5,10 +5,12 @@ from inspect import iscoroutinefunction from typing import List + from .button import Button, ButtonStyle from .select_menu import SelectMenu, SelectOption from .component_handler import ComponentHandler from ...exceptions import CommandIsNotCoroutine +from ...objects.app.command import InteractableStructure from ...objects.message.emoji import Emoji from ...utils.conversion import remove_none @@ -75,27 +77,26 @@ def wrap(custom_id, func) -> Button: if custom_id is None: custom_id = func.__name__ - ComponentHandler().register_id(custom_id, func) - - button = Button( - # Hack to not override defaults in button class - **remove_none( - { - "custom_id": custom_id, - "style": style, - "label": label, - "disabled": disabled, - "emoji": emoji, - "url": url, - "_func": func - } + interactable = InteractableStructure( + call=func, + metadata=Button( + # Hack to not override defaults in button class + **remove_none( + { + "custom_id": custom_id, + "style": style, + "label": label, + "disabled": disabled, + "emoji": emoji, + "url": url, + } + ) ) ) - button.func = func - button.__call__ = partial(func) + ComponentHandler.register[interactable.metadata.custom_id] = interactable - return button + return interactable return partial(wrap, custom_id) @@ -149,24 +150,26 @@ def wrap(custom_id, func) -> SelectMenu: if custom_id is None: custom_id = func.__name__ - ComponentHandler().register_id(custom_id, func) - - menu = SelectMenu( - # Hack to not override defaults in button class - **remove_none( - { - "custom_id": custom_id, - "options": options, - 
"placeholder": placeholder, - "min_values": min_values, - "max_values": max_values, - "disabled": disabled, - "_func": func - } + interactable = InteractableStructure( + call=func, + metadata=SelectMenu( + # Hack to not override defaults in button class + **remove_none( + { + "custom_id": custom_id, + "options": options, + "placeholder": placeholder, + "min_values": min_values, + "max_values": max_values, + "disabled": disabled, + } + ) ) ) - return menu + ComponentHandler.register[interactable.metadata.custom_id] = interactable + + return interactable if func is None: return partial(wrap, custom_id) diff --git a/pincer/commands/components/select_menu.py b/pincer/commands/components/select_menu.py index 5cec6158..c6f89d64 100644 --- a/pincer/commands/components/select_menu.py +++ b/pincer/commands/components/select_menu.py @@ -5,7 +5,7 @@ from copy import copy from dataclasses import dataclass -from typing import TYPE_CHECKING, Callable, Optional +from typing import TYPE_CHECKING, Optional from ._component import _Component from ...utils.api_object import APIObject @@ -81,14 +81,10 @@ class SelectMenu(_Component): disabled: APINullable[bool] = False type: int = 3 - _func: Optional[Callable] = None def __post_init__(self): self.type = 3 - def __call__(self, *args, **kwargs): - return self._func(*args, **kwargs) - def with_options(self, *options: SelectOption) -> SelectMenu: """ Sets the ``options`` parameter to \\*options and returns a new diff --git a/pincer/commands/groups.py b/pincer/commands/groups.py index a0e6d3dc..bb7d8d07 100644 --- a/pincer/commands/groups.py +++ b/pincer/commands/groups.py @@ -29,8 +29,8 @@ async def a_very_cool_command(): name : str The name of the command group. description : Optional[:class:`str`] - The description of the command. This has to be sent to Discord but it does - nothing so it is optional. + The description of the command. This has to be sent to Discord, but it does + nothing, so it is optional. 
""" name: str description: Optional[str] = None @@ -42,7 +42,7 @@ def __hash__(self) -> int: @dataclass class Subgroup: """ - A subgroup of commands. This allows you to create subcommands inside of a + A subgroup of commands. This allows you to create subcommands inside a subcommand-group. .. code-block:: python @@ -66,8 +66,8 @@ async def a_very_cool_command(): parent : :class:`~pincer.commands.groups.Group` The parent group of this command. description : Optional[:class:`str`] - The description of the command. This has to be sent to Discord but it does - nothing so it is optional. + The description of the command. This has to be sent to Discord, but it does + nothing, so it is optional. """ name: str parent: Group diff --git a/pincer/commands/interactable.py b/pincer/commands/interactable.py new file mode 100644 index 00000000..050ec2f7 --- /dev/null +++ b/pincer/commands/interactable.py @@ -0,0 +1,55 @@ +# Copyright Pincer 2021-Present +# Full MIT License can be found in `LICENSE` at the project root. + +from __future__ import annotations + +from collections import ChainMap + +from .. import client as _client +from .chat_command_handler import ChatCommandHandler +from .components.component_handler import ComponentHandler +from ..objects.app.command import AppCommand, InteractableStructure + + +INTERACTION_REGISTERS = ChainMap(ChatCommandHandler.register, ComponentHandler.register) + + +class Interactable: + """ + Class that can register :class:`~pincer.commands.interactable.PartialInteractable` + objects. Any class that subclasses this class can register Application Commands and + Message Components. + PartialInteractable objects are registered by running the register function and + setting an attribute of the client to the result. 
+ """ + + def __init__(self): + for value in vars(type(self)).values(): + if isinstance(value, InteractableStructure): + value.manager = self + + def __del__(self): + self.unassign() + + def unassign(self): + """ + Removes this objects loaded commands from ChatCommandHandler and + ComponentHandler and removes loaded events from the client. + """ + for value in vars(type(self)).values(): + if ( + isinstance(value, InteractableStructure) + and isinstance(value.metadata, AppCommand) + ): + for key, _value in INTERACTION_REGISTERS.items(): + if value is _value: + INTERACTION_REGISTERS.pop(key) + + key = value.call.__name__.lower() + + event_or_list = _client._events.get(key) + if isinstance(event_or_list, list): + if value in event_or_list: + event_or_list.remove(value) + else: + _client._events.pop(key, None) diff --git a/pincer/core/gateway.py b/pincer/core/gateway.py index 7a86f6fd..76d0f21c 100644 --- a/pincer/core/gateway.py +++ b/pincer/core/gateway.py @@ -136,7 +136,7 @@ def __init__( # `ClientWebSocketResponse` is a parent class. self.__socket: Optional[ClientWebSocketResponse] = None - # Buffer used to store information in transport conpression. + # Buffer used to store information in transport compression. self.__buffer = bytearray() # The gateway can be disconnected from Discord. This variable stores if the @@ -228,7 +228,7 @@ async def start_loop(self): ) await sleep(15) - _log.debug("%s Starting envent loop...", self.shard_key) + _log.debug("%s Starting event loop...", self.shard_key) await self.event_loop() async def event_loop(self): @@ -248,7 +248,7 @@ async def event_loop(self): # The loop is broken when the gateway stops receiving messages. # The "error" op codes are in `self.__close_codes`. The rest of the - # close codes are unknown issues (such as a unintended disconnect) so the + # close codes are unknown issues (such as an unintended disconnect) so the # client should reconnect to the gateway. 
err = self.__close_codes.get(self.__socket.close_code) @@ -266,7 +266,7 @@ async def handle_data(self, data: Dict[Any]): """|coro| Method is run when a payload is received from the gateway. The message is expected to already have been decompressed. - Handling the opcode is forked to the background so they aren't blocking. + Handling the opcode is forked to the background, so they aren't blocking. """ payload = GatewayDispatch.from_string(data) @@ -374,7 +374,7 @@ async def identify_and_handle_hello(self, payload: GatewayDispatch): async def handle_heartbeat(self, payload: GatewayDispatch): """|coro| - Opcode 11 - Heatbeat + Opcode 11 - Heartbeat Track that the heartbeat has been received using shared state (Rustaceans would be very mad) """ @@ -419,7 +419,7 @@ def stop_heartbeat(self): def send_next_heartbeat(self): """ - It is expected to always be waiting for a hearbeat. By canceling that task, + It is expected to always be waiting for a heartbeat. By canceling that task, a heartbeat can be sent. 
""" self.__wait_for_heartbeat.cancel() diff --git a/pincer/middleware/activity_join_request.py b/pincer/middleware/activity_join_request.py index b5f95963..8aa0ca45 100644 --- a/pincer/middleware/activity_join_request.py +++ b/pincer/middleware/activity_join_request.py @@ -37,7 +37,7 @@ async def activity_join_request_middleware( """ return ( "on_activity_join_request", - User.from_dict(self, payload.data), + User.from_dict(payload.data), ) diff --git a/pincer/middleware/interaction_create.py b/pincer/middleware/interaction_create.py index 29d97bfd..be4137ff 100644 --- a/pincer/middleware/interaction_create.py +++ b/pincer/middleware/interaction_create.py @@ -4,12 +4,13 @@ from __future__ import annotations import logging +from copy import copy from contextlib import suppress from inspect import isasyncgenfunction, _empty -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional from ..commands import ChatCommandHandler, ComponentHandler -from ..commands.commands import _hash_app_command_params +from ..commands.chat_command_handler import _hash_app_command_params from ..exceptions import InteractionDoesNotExist from ..objects import ( Interaction, @@ -87,16 +88,17 @@ def get_command_from_registry(interaction: Interaction): ) -def get_call(self: Client, interaction: Interaction): +def get_call(self: Client, interaction: Interaction) -> Optional[Tuple[Coro, Any]]: if interaction.type == InteractionType.APPLICATION_COMMAND: command = get_command_from_registry(interaction) if command is None: return None # Only application commands can be throttled self.throttler.handle(command) - return command.call + return command.call, command.manager elif interaction.type == InteractionType.MESSAGE_COMPONENT: - return ComponentHandler.register.get(interaction.data.custom_id) + command = ComponentHandler.register.get(interaction.data.custom_id) + return command.call, command.manager elif interaction.type == InteractionType.AUTOCOMPLETE: raise 
NotImplementedError( "Handling for autocomplete is not implemented" @@ -104,7 +106,9 @@ def get_call(self: Client, interaction: Interaction): async def interaction_response_handler( + self: Client, command: Coro, + manager: Any, context: MessageContext, interaction: Interaction, args: List[Any], @@ -125,12 +129,14 @@ async def interaction_response_handler( \\*\\*kwargs : The arguments to be passed to the command. """ - sig, params = get_signature_and_params(command) - if should_pass_ctx(sig, params): + # Prevent args from being mutated unexpectedly + args = copy(args) + + if should_pass_ctx(*get_signature_and_params(command)): args.insert(0, context) if should_pass_cls(command): - args.insert(0, ChatCommandHandler.managers[command.__module__]) + args.insert(0, manager or self) if isasyncgenfunction(command): message = command(*args, **kwargs) @@ -147,7 +153,11 @@ async def interaction_response_handler( async def interaction_handler( - interaction: Interaction, context: MessageContext, command: Coro + self: Client, + interaction: Interaction, + context: MessageContext, + command: Coro, + manager: Any ): """|coro| @@ -207,9 +217,25 @@ def get_options_from_command(options): kwargs = {**defaults, **params} - await interaction_response_handler( - command, context, interaction, args, kwargs - ) + try: + await interaction_response_handler( + self, command, manager, context, interaction, args, kwargs + ) + except Exception as e: + if coro := get_index(self.get_event_coro("on_command_error"), 0): + try: + await interaction_response_handler( + self, + coro.call, + coro.manager, + context, + interaction, + [e, *args], + kwargs, + ) + except Exception as e: + raise e + raise e async def interaction_create_middleware( @@ -237,29 +263,10 @@ async def interaction_create_middleware( ``on_interaction_create`` and an ``Interaction`` """ interaction: Interaction = Interaction.from_dict(payload.data) - - call = get_call(self, interaction) + call, manager = get_call(self, interaction) 
context = interaction.get_message_context() - try: - await interaction_handler(interaction, context, call) - except Exception as e: - if coro := get_index(self.get_event_coro("on_command_error"), 0): - params = get_signature_and_params(coro)[1] - - # Check if a context or error var has been passed. - if 0 < len(params) < 3: - await interaction_response_handler( - coro, - context, - interaction, - # Always take the error parameter its name. - {params[-1]: e}, - ) - else: - raise e - else: - raise e + await interaction_handler(self, interaction, context, call, manager) return "on_interaction_create", interaction diff --git a/pincer/middleware/thread_list_sync.py b/pincer/middleware/thread_list_sync.py index d2cb29f6..248dd90b 100644 --- a/pincer/middleware/thread_list_sync.py +++ b/pincer/middleware/thread_list_sync.py @@ -16,7 +16,7 @@ async def thread_list_sync( - self: Client, gatewayer: Gateway, payload: GatewayDispatch + self: Client, gateway: Gateway, payload: GatewayDispatch ): """|coro| diff --git a/pincer/middleware/thread_members_update.py b/pincer/middleware/thread_members_update.py index 22af72ea..9f784e67 100644 --- a/pincer/middleware/thread_members_update.py +++ b/pincer/middleware/thread_members_update.py @@ -16,7 +16,7 @@ async def thread_members_update_middleware( - self: Client, gatewayer: Gateway, payload: GatewayDispatch + self: Client, gateway: Gateway, payload: GatewayDispatch ): """|coro| diff --git a/pincer/middleware/thread_update.py b/pincer/middleware/thread_update.py index 9070c309..c463ac32 100644 --- a/pincer/middleware/thread_update.py +++ b/pincer/middleware/thread_update.py @@ -17,7 +17,7 @@ async def thread_update_middleware( - self: Client, gatewayer: Gateway, payload: GatewayDispatch + self: Client, gateway: Gateway, payload: GatewayDispatch ): """|coro| diff --git a/pincer/objects/__init__.py b/pincer/objects/__init__.py index 0e25d8a3..6e84fb5f 100644 --- a/pincer/objects/__init__.py +++ b/pincer/objects/__init__.py @@ -5,7 +5,7 
@@ from .app.command import ( AppCommandType, AppCommandOptionType, AppCommandInteractionDataOption, AppCommandOptionChoice, AppCommandOption, AppCommand, - ClientCommandStructure + InteractableStructure ) from .app.intents import Intents from .app.interaction_base import ( @@ -68,6 +68,7 @@ ) from .guild.member import GuildMember, PartialGuildMember, BaseMember from .guild.overwrite import Overwrite +from .guild.permissions import Permissions from .guild.role import RoleTags, Role from .guild.stage import PrivacyLevel, StageInstance from .guild.template import GuildTemplate @@ -104,17 +105,16 @@ from .voice.region import VoiceRegion __all__ = ( - "", "Activity", "ActivityAssets", "ActivityButton", - "ActivityEmoji", "ActivityFlags", "ActivityParty", "ActivitySecrets", - "ActivityTimestamp", "ActivityType", "AllowedMentionTypes", - "AllowedMentions", "AppCommand", "AppCommandInteractionDataOption", - "AppCommandOption", "AppCommandOptionChoice", "AppCommandOptionType", - "AppCommandType", "Application", "Attachment", "AuditEntryInfo", - "AuditLog", "AuditLogChange", "AuditLogEntry", "AuditLogEvent", "Ban", - "BaseMember", "CallbackType", "CategoryChannel", "Channel", - "ChannelMention", "ChannelPinsUpdateEvent", "ChannelType", - "ClientCommandStructure", "ClientStatus", "ComponentType", "Connection", - "DefaultMessageNotificationLevel", "DefaultThrottleHandler", + "Activity", "ActivityAssets", "ActivityButton", "ActivityEmoji", + "ActivityFlags", "ActivityParty", "ActivitySecrets", "ActivityTimestamp", + "ActivityType", "AllowedMentionTypes", "AllowedMentions", "AppCommand", + "AppCommandInteractionDataOption", "AppCommandOption", + "AppCommandOptionChoice", "AppCommandOptionType", "AppCommandType", + "Application", "Attachment", "AuditEntryInfo", "AuditLog", + "AuditLogChange", "AuditLogEntry", "AuditLogEvent", "Ban", "BaseMember", + "CallbackType", "CategoryChannel", "Channel", "ChannelMention", + "ChannelPinsUpdateEvent", "ChannelType", "ClientStatus", 
"ComponentType", + "Connection", "DefaultMessageNotificationLevel", "DefaultThrottleHandler", "DiscordError", "Embed", "EmbedAuthor", "EmbedField", "EmbedFooter", "EmbedImage", "EmbedProvider", "EmbedThumbnail", "EmbedVideo", "Emoji", "ExplicitContentFilterLevel", "File", "FollowedChannel", "Guild", @@ -125,25 +125,25 @@ "GuildRoleDeleteEvent", "GuildRoleUpdateEvent", "GuildStickersUpdateEvent", "GuildTemplate", "GuildWidget", "HelloEvent", "Identify", "Integration", "IntegrationAccount", "IntegrationApplication", "IntegrationDeleteEvent", - "IntegrationExpireBehavior", "Intents", "Interaction", "InteractionData", - "InteractionFlags", "InteractionType", "Invite", "InviteCreateEvent", - "InviteDeleteEvent", "InviteStageInstance", "InviteTargetType", "MFALevel", - "Mentionable", "Message", "MessageActivity", "MessageActivityType", - "MessageComponent", "MessageContext", "MessageDeleteBulkEvent", - "MessageDeleteEvent", "MessageFlags", "MessageInteraction", - "MessageReactionAddEvent", "MessageReactionRemoveAllEvent", - "MessageReactionRemoveEmojiEvent", "MessageReactionRemoveEvent", - "MessageReference", "MessageType", "NewsChannel", "Overwrite", - "PartialGuildMember", "PremiumTier", "PremiumTypes", "PresenceUpdateEvent", - "PrivacyLevel", "Reaction", "ReadyEvent", "RequestGuildMembers", - "ResolvedData", "Resume", "Role", "RoleTags", "SessionStartLimit", - "StageInstance", "StatusType", "Sticker", "StickerFormatType", - "StickerItem", "StickerPack", "StickerType", "SystemChannelFlags", - "TextChannel", "ThreadListSyncEvent", "ThreadMember", - "ThreadMembersUpdateEvent", "ThreadMetadata", "ThrottleInterface", - "ThrottleScope", "TypingStartEvent", "UpdatePresence", "UpdateVoiceState", - "User", "UserMessage", "VerificationLevel", "VisibilityType", - "VoiceChannel", "VoiceRegion", "VoiceServerUpdateEvent", "VoiceState", - "Webhook", "WebhookType", "WebhooksUpdateEvent", "WelcomeScreen", - "WelcomeScreenChannel" + "IntegrationExpireBehavior", "Intents", 
"InteractableStructure", + "Interaction", "InteractionData", "InteractionFlags", "InteractionType", + "Invite", "InviteCreateEvent", "InviteDeleteEvent", "InviteStageInstance", + "InviteTargetType", "MFALevel", "Mentionable", "Message", + "MessageActivity", "MessageActivityType", "MessageComponent", + "MessageContext", "MessageDeleteBulkEvent", "MessageDeleteEvent", + "MessageFlags", "MessageInteraction", "MessageReactionAddEvent", + "MessageReactionRemoveAllEvent", "MessageReactionRemoveEmojiEvent", + "MessageReactionRemoveEvent", "MessageReference", "MessageType", + "NewsChannel", "Overwrite", "PartialGuildMember", "Permissions", + "PremiumTier", "PremiumTypes", "PresenceUpdateEvent", "PrivacyLevel", + "Reaction", "ReadyEvent", "RequestGuildMembers", "ResolvedData", "Resume", + "Role", "RoleTags", "SessionStartLimit", "StageInstance", "StatusType", + "Sticker", "StickerFormatType", "StickerItem", "StickerPack", + "StickerType", "SystemChannelFlags", "TextChannel", "ThreadListSyncEvent", + "ThreadMember", "ThreadMembersUpdateEvent", "ThreadMetadata", + "ThrottleInterface", "ThrottleScope", "TypingStartEvent", "UpdatePresence", + "UpdateVoiceState", "User", "UserMessage", "VerificationLevel", + "VisibilityType", "VoiceChannel", "VoiceRegion", "VoiceServerUpdateEvent", + "VoiceState", "Webhook", "WebhookType", "WebhooksUpdateEvent", + "WelcomeScreen", "WelcomeScreenChannel" ) diff --git a/pincer/objects/app/__init__.py b/pincer/objects/app/__init__.py index 32938851..d7e96006 100644 --- a/pincer/objects/app/__init__.py +++ b/pincer/objects/app/__init__.py @@ -5,7 +5,7 @@ from .command import ( AppCommandInteractionDataOption, AppCommandOptionChoice, AppCommandOption, AppCommand, - ClientCommandStructure + InteractableStructure ) from .command_types import AppCommandType, AppCommandOptionType from .intents import Intents @@ -21,8 +21,8 @@ __all__ = ( "AppCommand", "AppCommandInteractionDataOption", "AppCommandOption", "AppCommandOptionChoice", "AppCommandOptionType", 
- "AppCommandType", "Application", "CallbackType", "ClientCommandStructure", - "DefaultThrottleHandler", "Intents", "Interaction", "InteractionData", + "AppCommandType", "Application", "CallbackType", "DefaultThrottleHandler", + "Intents", "InteractableStructure", "Interaction", "InteractionData", "InteractionFlags", "InteractionType", "Mentionable", "MessageInteraction", "ResolvedData", "SessionStartLimit", "ThrottleInterface", "ThrottleScope" ) diff --git a/pincer/objects/app/command.py b/pincer/objects/app/command.py index c34128e8..3c3b9046 100644 --- a/pincer/objects/app/command.py +++ b/pincer/objects/app/command.py @@ -3,13 +3,12 @@ from __future__ import annotations -from dataclasses import dataclass -from typing import List, Union, TYPE_CHECKING - - -from pincer.commands.groups import Group, Subgroup +from dataclasses import dataclass, field +from typing import Any, Awaitable, Callable, Generic, List, Optional, Union, TYPE_CHECKING, TypeVar from .command_types import AppCommandOptionType, AppCommandType +from ..app.throttle_scope import ThrottleScope +from ...commands.groups import Group, Subgroup from ...objects.guild.channel import ChannelType from ...utils.api_object import APIObject, GuildProperty from ...utils.snowflake import Snowflake @@ -18,8 +17,8 @@ if TYPE_CHECKING: from ...utils.types import APINullable - from ..app.throttle_scope import ThrottleScope +T = TypeVar("T") @dataclass(repr=False) class AppCommandInteractionDataOption(APIObject): @@ -58,6 +57,10 @@ class AppCommandOptionChoice(APIObject): name: str value: choice_value_types + def __post_init__(self): + # APIObject __post_init_ causes issues by converting `value` to a string + self.name = str(self.name) + @dataclass(repr=False) class AppCommandOption(APIObject): @@ -142,8 +145,8 @@ def __post_init__(self): if self.options is MISSING and self.type is AppCommandType.MESSAGE: self.options = [] - def __eq__(self, other: Union[AppCommand, ClientCommandStructure]): - if isinstance(other, 
ClientCommandStructure): + def __eq__(self, other: Union[AppCommand, InteractableStructure]): + if isinstance(other, InteractableStructure): other = other.app # `description` and `options` are tested for equality with a custom check @@ -184,29 +187,41 @@ def add_option(self, option: AppCommandOption): @dataclass(repr=False) -class ClientCommandStructure: +class InteractableStructure(Generic[T]): """Represents the structure of how the client saves the existing - commands in the register. + commands to registers. This is generic over Application Commands, + Message Components, and Autocomplete. Attributes ---------- - app: :class:`~pincer.objects.app.command.AppCommand` - The command application. call: :class:`~pincer.utils.types.Coro` The coroutine which should be called when the command gets executed. + metadata: T + The metadata for this command. |default| :data:`None` + manager : Optional[Any] + The manager for this interactable. |default| :data:`None` + extensions: List[Callable[..., Awaitable[bool]]] + List of extensions for this command. 
|default| :data:`[]` cooldown: :class:`int` - Amount of times for cooldown + Amount of times for cooldown |default| :data:`0` cooldown_scale: :class:`float` - Search time for cooldown + Search time for cooldown |default| :data:`60.0` cooldown_scope: :class:`~pincer.objects.app.throttle_scope.ThrottleScope` - The type of cooldown + The type of cooldown |default| :data:`ThrottleScope.USER` """ - app: AppCommand call: Coro - cooldown: int - cooldown_scale: float - cooldown_scope: ThrottleScope + + metadata: Optional[T] = None + manager: Optional[Any] = None + extensions: List[Callable[..., Awaitable[bool]]] = field(default_factory=list) + + cooldown: int = 0 + cooldown_scale: float = 60.0 + cooldown_scope: ThrottleScope = ThrottleScope.USER group: APINullable[Group] = MISSING sub_group: APINullable[Subgroup] = MISSING + + def __call__(self, *args: Any, **kwargs: Any) -> Any: + self.call(*args, **kwargs) diff --git a/pincer/objects/app/intents.py b/pincer/objects/app/intents.py index 2623854d..b1ae6457 100644 --- a/pincer/objects/app/intents.py +++ b/pincer/objects/app/intents.py @@ -3,10 +3,10 @@ from __future__ import annotations -from enum import IntEnum +from enum import IntFlag -class Intents(IntEnum): +class Intents(IntFlag): """Discord client intents. These give your client more permissions. 
@@ -67,21 +67,10 @@ class Intents(IntEnum): DIRECT_MESSAGE_REACTIONS = 1 << 13 DIRECT_MESSAGE_TYPING = 1 << 14 - @staticmethod - def all() -> int: + @classmethod + def all(cls) -> Intents: """ :class:`~pincer.objects.app.intents.Intents`: Method of all intents """ - res = 0 - - for intent in list(map(lambda itm: itm.value, Intents)): - res |= intent - - return res - - def __repr__(self): - return f"Intents({self.name})" - - def __str__(self) -> str: - return self.name.lower().replace("_", " ") + return cls(sum(cls)) diff --git a/pincer/objects/app/interaction_flags.py b/pincer/objects/app/interaction_flags.py index 5a7ae3f4..5492209f 100644 --- a/pincer/objects/app/interaction_flags.py +++ b/pincer/objects/app/interaction_flags.py @@ -1,10 +1,10 @@ # Copyright Pincer 2021-Present # Full MIT License can be found in `LICENSE` at the project root. -from enum import IntEnum +from enum import IntFlag -class InteractionFlags(IntEnum): +class InteractionFlags(IntFlag): """ Attributes diff --git a/pincer/objects/app/interactions.py b/pincer/objects/app/interactions.py index 5d54b87a..08b5fe51 100644 --- a/pincer/objects/app/interactions.py +++ b/pincer/objects/app/interactions.py @@ -198,10 +198,7 @@ def return_type( data : Dict[:class:`~pincer.utils.types.Snowflake`, Any] Resolved data to search through. """ - if data: - return data[option.value] - - return None + return data[option.value] if data else None def get_message_context(self): return MessageContext( @@ -382,7 +379,7 @@ async def _base_reply( async def reply(self, message: MessageConvertable) -> UserMessage: """|coro| - Sends a reply to a interaction. + Sends a reply to an interaction. 
""" return await self._base_reply(message, CallbackType.MESSAGE, False) diff --git a/pincer/objects/app/throttling.py b/pincer/objects/app/throttling.py index e35cda2e..31eeadfa 100644 --- a/pincer/objects/app/throttling.py +++ b/pincer/objects/app/throttling.py @@ -3,10 +3,10 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, DefaultDict from .throttle_scope import ThrottleScope -from ..app.command import ClientCommandStructure +from ..app.command import InteractableStructure from ...exceptions import CommandCooldownError from ...utils.slidingwindow import SlidingWindow @@ -18,16 +18,16 @@ class ThrottleInterface(ABC): """An ABC for throttling.""" - throttle: Dict[Coro, Dict[Optional[str], SlidingWindow]] = {} + throttle: DefaultDict[Coro, Dict[Optional[str], SlidingWindow]] = DefaultDict(dict) @staticmethod @abstractmethod - def handle(command: ClientCommandStructure, **kwargs): + def handle(command: InteractableStructure, **kwargs): raise NotImplementedError class DefaultThrottleHandler(ThrottleInterface, ABC): - """The default throttlehandler based off the + """The default throttle-handler based off the :class:`~pincer.objects.app.throttling.ThrottleInterface` ABC """ __throttle_scopes = { @@ -38,7 +38,7 @@ class DefaultThrottleHandler(ThrottleInterface, ABC): } @staticmethod - def get_key_from_scope(command: ClientCommandStructure) -> Optional[int]: + def get_key_from_scope(command: InteractableStructure) -> Optional[int]: """Retrieve the appropriate key from the context through the throttle scope. 
@@ -50,7 +50,7 @@ def get_key_from_scope(command: ClientCommandStructure) -> Optional[int]: Returns ------- Optional[:class:`int`] - The throttlescope enum + The throttle-scope enum """ scope = DefaultThrottleHandler.__throttle_scopes[command.cooldown_scope] @@ -65,12 +65,12 @@ def get_key_from_scope(command: ClientCommandStructure) -> Optional[int]: return last_obj @staticmethod - def init_throttler(command: ClientCommandStructure, throttle_key: Optional[int]): + def init_throttler(command: InteractableStructure, throttle_key: Optional[int]): DefaultThrottleHandler.throttle[command.call][throttle_key] \ = SlidingWindow(command.cooldown, command.cooldown_scale) @staticmethod - def handle(command: ClientCommandStructure, **kwargs): + def handle(command: InteractableStructure, **kwargs): if command.cooldown <= 0: return diff --git a/pincer/objects/events/presence.py b/pincer/objects/events/presence.py index 0b1b296a..a3f2947e 100644 --- a/pincer/objects/events/presence.py +++ b/pincer/objects/events/presence.py @@ -199,7 +199,7 @@ class Activity(APIObject): secrets: APINullable[:class:`~pincer.objects.events.presence.ActivitySecrets`] Secrets for Rich Presence joining and spectating instance: APINullable[:class:`bool`] - "nether or not the activity is an instanced game session + whether or not the activity is an instanced game session flags: APINullable[:class:`~pincer.objects.events.presence.ActivityFlags`] Activity flags ``OR``\\d together, describes what the payload includes diff --git a/pincer/objects/guild/__init__.py b/pincer/objects/guild/__init__.py index 3d1aaf7e..e16421fb 100644 --- a/pincer/objects/guild/__init__.py +++ b/pincer/objects/guild/__init__.py @@ -21,8 +21,9 @@ ) from .member import GuildMember, PartialGuildMember, BaseMember from .overwrite import Overwrite +from .permissions import Permissions from .role import RoleTags, Role -from .scheduled_events import GuildScheduledEventEntityType, EventStatus, ScheduledEvent +from .scheduled_events 
import GuildScheduledEventEntityType, GuildScheduledEventUser, EventStatus, ScheduledEvent from .stage import PrivacyLevel, StageInstance from .template import GuildTemplate from .thread import ThreadMetadata, ThreadMember @@ -32,14 +33,15 @@ __all__ = ( - "AuditEntryInfo", "AuditLog", "AuditLogChange", "AuditLogEntry", - "AuditLogEvent", "Ban", "BaseMember", "CategoryChannel", "Channel", - "ChannelMention", "ChannelType", "DefaultMessageNotificationLevel", - "EventStatus", "ExplicitContentFilterLevel", "FollowedChannel", "Guild", - "GuildFeature", "GuildMember", "GuildNSFWLevel", - "GuildScheduledEventEntityType", "GuildTemplate", "GuildWidget", "Invite", - "InviteStageInstance", "InviteTargetType", "MFALevel", "NewsChannel", - "Overwrite", "PartialGuildMember", "PremiumTier", "PrivacyLevel", "Role", + "", "AuditEntryInfo", "AuditLog", "AuditLogChange", + "AuditLogEntry", "AuditLogEvent", "Ban", "BaseMember", "CategoryChannel", + "Channel", "ChannelMention", "ChannelType", + "DefaultMessageNotificationLevel", "EventStatus", + "ExplicitContentFilterLevel", "FollowedChannel", "Guild", "GuildFeature", + "GuildMember", "GuildNSFWLevel", "GuildScheduledEventEntityType", + "GuildTemplate", "GuildWidget", "Invite", "InviteStageInstance", + "InviteTargetType", "MFALevel", "NewsChannel", "Overwrite", + "PartialGuildMember", "Permissions", "PremiumTier", "PrivacyLevel", "Role", "RoleTags", "ScheduledEvent", "StageInstance", "SystemChannelFlags", "TextChannel", "ThreadMember", "ThreadMetadata", "UnavailableGuild", "VerificationLevel", "VoiceChannel", "Webhook", "WebhookType", diff --git a/pincer/objects/guild/channel.py b/pincer/objects/guild/channel.py index d1664274..10bc7e81 100644 --- a/pincer/objects/guild/channel.py +++ b/pincer/objects/guild/channel.py @@ -6,18 +6,18 @@ from asyncio import sleep, ensure_future from dataclasses import dataclass from enum import IntEnum -from urllib.parse import urlencode from typing import AsyncIterator, overload, TYPE_CHECKING from 
.invite import Invite, InviteTargetType from ..message.user_message import UserMessage from ..._config import GatewayConfig +from ...utils.api_data import APIDataGen from ...utils.api_object import APIObject, GuildProperty from ...utils.convert_message import convert_message from ...utils.types import MISSING if TYPE_CHECKING: - from typing import AsyncGenerator, Dict, List, Optional, Union + from typing import Dict, List, Optional, Union from .member import GuildMember from .overwrite import Overwrite @@ -243,7 +243,7 @@ async def edit(self, reason: Optional[str] = None, **kwargs): Parameters ---------- reason Optional[:class:`str`] - The reason of the channel delete. + The reason of the channel edit. \\*\\*kwargs : The keyword arguments to edit the channel with. @@ -289,7 +289,7 @@ async def edit_permissions( type: :class:`int` 0 for a role or 1 for a member. reason: Optional[:class:`str`] - The reason of the channel delete. + The reason of the channel permission edit. """ await self._http.put( f"channels/{self.id}/permissions/{overwrite.id}", @@ -309,7 +309,7 @@ async def delete_permission( overwrite: :class:`~pincer.objects.guild.overwrite.Overwrite` The overwrite object. reason: Optional[:class:`str`] - The reason of the channel delete. + The reason of the channel permission delete. """ await self._http.delete( f"channels/{self.id}/permissions/{overwrite.id}", @@ -351,7 +351,7 @@ async def trigger_typing_indicator(self): """ await self._http.post(f"channels/{self.id}/typing") - async def get_pinned_messages(self) -> AsyncIterator[UserMessage]: + def get_pinned_messages(self) -> APIDataGen[UserMessage]: """|coro| Fetches all pinned messages in the channel. Returns an iterator of pinned messages. @@ -361,9 +361,10 @@ async def get_pinned_messages(self) -> AsyncIterator[UserMessage]: :class:`AsyncIterator[:class:`~pincer.objects.guild.message.UserMessage`]` An iterator of pinned messages. 
""" - data = await self._http.get(f"channels/{self.id}/pins") - for message in data: - yield UserMessage.from_dict(message) + return APIDataGen( + UserMessage, + self._http.get(f"channels/{self.id}/pins") + ) async def pin_message( self, message: UserMessage, reason: Optional[str] = None @@ -371,6 +372,11 @@ async def pin_message( """|coro| Pin a message in a channel. Requires the ``MANAGE_MESSAGES`` permission. The maximum number of pinned messages is ``50``. + + Parameters + ---------- + reason: Optional[:class:`str`] + The reason of the channel message pin. """ await self._http.put( f"channels/{self.id}/pins/{message.id}", @@ -382,6 +388,11 @@ async def unpin_message( ): """|coro| Unpin a message in a channel. Requires the ``MANAGE_MESSAGES`` permission. + + Parameters + ---------- + reason: Optional[:class:`str`] + The reason of the channel message unpin. """ await self._http.delete( f"channels/{self.id}/pins/{message.id}", @@ -440,7 +451,7 @@ async def bulk_delete_messages( messages: List[:class:`~.pincer.utils.Snowflake`] The list of message IDs to delete (2-100). reason: Optional[:class:`str`] - The reason of the channel delete. + The reason of the channel bulk delete. """ await self._http.post( f"channels/{self.id}/messages/bulk_delete", @@ -472,7 +483,8 @@ async def delete( headers={"X-Audit-Log-Reason": reason}, ) - async def __post_send_handler(self, message: UserMessage): + @staticmethod + async def __post_send_handler(message: UserMessage): """Process a message after it was sent. Parameters @@ -519,7 +531,7 @@ async def send(self, message: Union[Embed, Message, str]) -> UserMessage: self.__post_sent(msg) return msg - async def get_webhooks(self) -> AsyncGenerator[Webhook, None]: + def get_webhooks(self) -> APIDataGen[Webhook]: """|coro| Get all webhooks in the channel. Requires the ``MANAGE_WEBHOOKS`` permission. 
@@ -528,11 +540,12 @@ async def get_webhooks(self) -> AsyncGenerator[Webhook, None]: ------- AsyncGenerator[:class:`~.pincer.objects.guild.webhook.Webhook`, None] """ - data = await self._http.get(f"channels/{self.id}/webhooks") - for webhook_data in data: - yield Webhook.from_dict(webhook_data) + return APIDataGen( + Webhook, + self._http.get(f"channels/{self.id}/webhooks") + ) - async def get_invites(self) -> AsyncIterator[Invite]: + def get_invites(self) -> APIDataGen[Invite]: """|coro| Fetches all the invite objects for the channel. Only usable for guild channels. Requires the ``MANAGE_CHANNELS`` permission. @@ -542,9 +555,10 @@ async def get_invites(self) -> AsyncIterator[Invite]: AsyncIterator[:class:`~pincer.objects.guild.invite.Invite`] Invites iterator. """ - data = await self._http.get(f"channels/{self.id}/invites") - for invite in data: - yield Invite.from_dict(invite) + return APIDataGen( + Invite, + self._http.get(f"channels/{self.id}/invites") + ) async def create_invite( self, @@ -779,6 +793,57 @@ async def fetch_message(self, message_id: int) -> UserMessage: await self._http.get(f"channels/{self.id}/messages/{message_id}") ) + async def history( + self, limit: int = 50, + before: Optional[Union[int, Snowflake]] = None, + after: Optional[Union[int, Snowflake]] = None, + around: Optional[Union[int, Snowflake]] = None, + ) -> AsyncIterator[UserMessage]: + """|coro| + Returns a list of messages in this channel. + + Parameters + ---------- + around : Optional[Union[:class:`int`, :class:`Snowflake`]] + The message ID to look around. + after : Optional[Union[:class:`int`, :class:`Snowflake`]] + The message ID to look after. + before : Optional[Union[:class:`int`, :class:`Snowflake`]] + The message ID to look before. + limit : Optional[Union[:class:`int`, :class:`Snowflake`]] + The maximum number of messages to return. + + Returns + ------- + AsyncIterator[:class:`~pincer.objects.message.user_message.UserMessage`] + An iterator of messages. 
+ """ + + if limit is None: + limit = float('inf') + + while limit > 0: + retrieve = min(limit, 100) + + raw_messages = await self._http.get( + f'/channels/{self.id}/messages', + params={ + 'limit': retrieve, + 'before': before, + 'after': after, + 'around': around, + } + ) + + if not raw_messages: + break + + for _message in raw_messages: + yield UserMessage.from_dict(_message) + + before = raw_messages[-1]['id'] + limit -= retrieve + class VoiceChannel(Channel): """A subclass of ``Channel`` for voice channels with all the same attributes.""" @@ -798,7 +863,7 @@ async def edit( async def edit(self, **kwargs): """|coro| - Edit a text channel with the given keyword arguments. + Edit a voice channel with the given keyword arguments. Parameters ---------- @@ -842,7 +907,7 @@ async def edit( async def edit(self, **kwargs): """|coro| - Edit a text channel with the given keyword arguments. + Edit a news channel with the given keyword arguments. Parameters ---------- @@ -1040,7 +1105,7 @@ async def get_member(self, user: User) -> ThreadMember: await self._http.get(f"channels/{self.id}/thread-members/{user.id}") ) - async def list_members(self) -> AsyncIterator[ThreadMember]: + def list_members(self) -> APIDataGen[ThreadMember]: """|coro| Fetches all the thread members for the thread. Returns an iterator of ThreadMember objects. @@ -1050,9 +1115,10 @@ async def list_members(self) -> AsyncIterator[ThreadMember]: AsyncIterator[:class:`~pincer.objects.channel.ThreadMember`] An iterator of thread members. 
""" - data = await self._http.get(f"channels/{self.id}/thread-members") - for member in data: - yield ThreadMember.from_dict(member) + return APIDataGen( + ThreadMember, + self._http.get(f"channels/{self.id}/thread-members") + ) class PublicThread(Thread): diff --git a/pincer/objects/guild/features.py b/pincer/objects/guild/features.py index 9d4f7990..27ceadec 100644 --- a/pincer/objects/guild/features.py +++ b/pincer/objects/guild/features.py @@ -3,7 +3,7 @@ from __future__ import annotations -from enum import Enum, auto +from enum import Enum class GuildFeature(Enum): @@ -79,4 +79,8 @@ class GuildFeature(Enum): PRIVATE_THREADS = "PRIVATE_THREADS" NEW_THREAD_PERMISSIONS = "NEW_THREAD_PERMISSIONS" THREADS_ENABLED = "THREADS_ENABLED" + ROLE_ICONS = "ROLE_ICONS" + ANIMATED_BANNER = "ANIMATED_BANNER" + MEMBER_PROFILES = "MEMBER_PROFILES" + ENABLED_DISCOVERABLE_BEFORE = "ENABLED_DISCOVERABLE_BEFORE" diff --git a/pincer/objects/guild/guild.py b/pincer/objects/guild/guild.py index 46121391..0f16384c 100644 --- a/pincer/objects/guild/guild.py +++ b/pincer/objects/guild/guild.py @@ -4,21 +4,25 @@ from __future__ import annotations from dataclasses import dataclass, field +from datetime import datetime from enum import IntEnum -from typing import AsyncGenerator, overload, TYPE_CHECKING +from typing import overload, TYPE_CHECKING from aiohttp import FormData from .channel import Channel, Thread +from .scheduled_events import ScheduledEvent, GuildScheduledEventUser +from ..message.emoji import Emoji from ..message.file import File from ...exceptions import UnavailableGuildError +from ...utils import remove_none +from ...utils.api_data import APIDataGen from ...utils.api_object import APIObject from ...utils.types import MISSING if TYPE_CHECKING: from typing import Any, Dict, List, Optional, Tuple, Union, Generator - from collections.abc import AsyncIterator from .audit_log import AuditLog from .ban import Ban from .channel import ChannelType @@ -27,7 +31,6 @@ from .invite 
import Invite from .overwrite import Overwrite from .role import Role - from .scheduled_events import ScheduledEvent from .stage import StageInstance from .template import GuildTemplate from .welcome_screen import WelcomeScreen, WelcomeScreenChannel @@ -37,7 +40,6 @@ from ..user.integration import Integration from ..voice.region import VoiceRegion from ..events.presence import PresenceUpdateEvent - from ..message.emoji import Emoji from ..message.sticker import Sticker from ..user.voice_state import VoiceState from ...client import Client @@ -104,7 +106,7 @@ class ExplicitContentFilterLevel(IntEnum): class MFALevel(IntEnum): - """Represents the multi factor authentication level of a guild. + """Represents the multi-factor authentication level of a guild. Attributes ---------- NONE: @@ -619,9 +621,9 @@ async def list_active_threads( return threads, members - async def list_guild_members( + def list_guild_members( self, limit: int = 1, after: int = 0 - ) -> AsyncIterator[GuildMember]: + ) -> APIDataGen[GuildMember]: """|coro| Returns a list of guild member objects that are members of the guild. @@ -638,16 +640,17 @@ async def list_guild_members( the guild member object that is in the guild """ - members = await self._http.get( - f"guilds/{self.id}/members", params={"limit": limit, "after": after} + return APIDataGen( + GuildMember, + self._http.get( + f"guilds/{self.id}/members", + params={"limit": limit, "after": after} + ) ) - for member in members: - yield GuildMember.from_dict(member) - - async def search_guild_members( + def search_guild_members( self, query: str, limit: Optional[int] = None - ) -> AsyncIterator[GuildMember]: + ) -> APIDataGen[GuildMember]: """|coro| Returns a list of guild member objects whose username or nickname starts with a provided string. 
@@ -665,14 +668,14 @@ async def search_guild_members( guild member objects """ - data = await self._http.get( - f"guilds/{self.id}/members/search", - params={"query": query, "limit": limit}, + return APIDataGen( + GuildMember, + self._http.get( + f"guilds/{self.id}/members/search", + params={"query": query, "limit": limit}, + ) ) - for member in data: - yield GuildMember.from_dict(member) - @overload async def add_guild_member( self, @@ -763,7 +766,7 @@ async def add_guild_member_role( audit log reason |default| :data:`None` """ data = await self._http.put( - f"guilds/{self.id}/{user_id}/roles/{role_id}", + f"guilds/{self.id}/members/{user_id}/roles/{role_id}", headers={"X-Audit-Log-Reason": reason}, ) @@ -783,7 +786,7 @@ async def remove_guild_member_role( audit log reason |default| :data:`None` """ await self._http.delete( - f"guilds/{self.id}/{user_id}/roles/{role_id}", + f"guilds/{self.id}/members/{user_id}/roles/{role_id}", headers={"X-Audit-Log-Reason": reason}, ) @@ -807,21 +810,24 @@ async def remove_guild_member( async def ban( self, - member_id: int, + member: Union[int, GuildMember], reason: str = None, delete_message_days: int = None, ): """ + Ban a guild member. + Parameters ---------- - member_id : :class:`int` - ID of the guild member to ban. + member : Union[:class:`int`, :class:`GuildMember`] + ID or object of the guild member to ban. reason : Optional[:class:`str`] Reason for the kick. delete_message_days : Optional[:class:`int`] Number of days to delete messages for (0-7) """ headers = {} + member_id: int = member if isinstance(member, int) else member.id if reason is not None: headers["X-Audit-Log-Reason"] = reason @@ -835,18 +841,23 @@ async def ban( f"/guilds/{self.id}/bans/{member_id}", data=data, headers=headers ) - async def kick(self, member_id: int, reason: Optional[str] = None): + async def kick( + self, + member: Union[int, GuildMember], + reason: Optional[str] = None + ): """|coro| Kicks a guild member. 
+ Parameters ---------- - member_id : :class:`int` - ID of the guild member to kick. + member : Union[:class:`int`, :class:`GuildMember`] + ID or object of the guild member to kick. reason : Optional[:class:`str`] Reason for the kick. """ - headers = {} + member_id: int = member if isinstance(member, int) else member.id if reason is not None: headers["X-Audit-Log-Reason"] = reason @@ -855,7 +866,7 @@ async def kick(self, member_id: int, reason: Optional[str] = None): f"/guilds/{self.id}/members/{member_id}", headers=headers ) - async def get_roles(self) -> AsyncGenerator[Role, None]: + def get_roles(self) -> APIDataGen[Role]: """|coro| Fetches all the roles in the guild. @@ -864,9 +875,10 @@ async def get_roles(self) -> AsyncGenerator[Role, None]: AsyncGenerator[:class:`~pincer.objects.guild.role.Role`, :data:`None`] An async generator of Role objects. """ - data = await self._http.get(f"guilds/{self.id}/roles") - for role_data in data: - yield Role.from_dict(role_data) + + return APIDataGen( + Role, self._http.get(f"guilds/{self.id}/roles") + ) @overload async def create_role( @@ -925,12 +937,12 @@ async def create_role(self, reason: Optional[str] = None, **kwargs) -> Role: ) ) - async def edit_role_position( + def edit_role_position( self, id: Snowflake, reason: Optional[str] = None, position: Optional[int] = None, - ) -> AsyncGenerator[Role, None]: + ) -> APIDataGen[Role]: """|coro| Edits the position of a role. @@ -948,13 +960,13 @@ async def edit_role_position( AsyncGenerator[:class:`~pincer.objects.guild.role.Role`, :data:`None`] An async generator of all the guild's role objects. 
""" - data = await self._http.patch( - f"guilds/{self.id}/roles", - data={"id": id, "position": position}, - headers={"X-Audit-Log-Reason": reason}, + return APIDataGen( + Role, self._http.patch( + f"guilds/{self.id}/roles", + data={"id": id, "position": position}, + headers={"X-Audit-Log-Reason": reason}, + ) ) - for role_data in data: - yield Role.from_dict(role_data) @overload async def edit_role( @@ -1034,7 +1046,7 @@ async def delete_role(self, id: Snowflake, reason: Optional[str] = None): headers={"X-Audit-Log-Reason": reason}, ) - async def get_bans(self) -> AsyncGenerator[Ban, None]: + def get_bans(self) -> APIDataGen[Ban]: """|coro| Fetches all the bans in the guild. @@ -1043,9 +1055,11 @@ async def get_bans(self) -> AsyncGenerator[Ban, None]: AsyncGenerator[:class:`~pincer.objects.guild.ban.Ban`, :data:`None`] An async generator of Ban objects. """ - data = await self._http.get(f"guilds/{self.id}/bans") - for ban_data in data: - yield Ban.from_dict(ban_data) + + return APIDataGen( + Ban, + self._http.get(f"guilds/{self.id}/bans") + ) async def get_ban(self, id: Snowflake) -> Ban: """|coro| @@ -1252,7 +1266,7 @@ async def prune( headers={"X-Audit-Log-Reason": reason}, )["pruned"] - async def get_voice_regions(self) -> AsyncGenerator[VoiceRegion, None]: + def get_voice_regions(self) -> APIDataGen[VoiceRegion]: """|coro| Returns an async generator of voice regions. @@ -1261,11 +1275,13 @@ async def get_voice_regions(self) -> AsyncGenerator[VoiceRegion, None]: AsyncGenerator[:class:`~pincer.objects.voice.VoiceRegion`, :data:`None`] An async generator of voice regions. 
""" - data = await self._http.get(f"guilds/{self.id}/regions") - for voice_region_data in data: - yield VoiceRegion.from_dict(voice_region_data) - async def get_invites(self) -> AsyncGenerator[Invite, None]: + return APIDataGen( + VoiceRegion, + self._http.get(f"guilds/{self.id}/regions") + ) + + def get_invites(self) -> APIDataGen[Invite]: """|coro| Returns an async generator of invites for the guild. Requires the ``MANAGE_GUILD`` permission. @@ -1275,11 +1291,30 @@ async def get_invites(self) -> AsyncGenerator[Invite, None]: AsyncGenerator[:class:`~pincer.objects.invite.Invite`, :data:`None`] An async generator of invites. """ - data = await self._http.get(f"guilds/{self.id}/invites") - for invite_data in data: - yield Invite.from_dict(invite_data) - async def get_integrations(self) -> AsyncIterator[Integration]: + return APIDataGen( + Invite, + self._http.get(f"guilds/{self.id}/invites") + ) + + async def get_invite(self, code: str) -> Invite: + """|coro| + Returns an Invite object for the given invite code. + + Parameters + ---------- + code : :class:`str` + The invite code to get the invite for. + + Returns + ------- + :class:`~pincer.objects.guild.invite.Invite` + The invite object. + """ + data = await self._http.get(f"invite/{code}") + return Invite.from_dict(data) + + def get_integrations(self) -> APIDataGen[Integration]: """|coro| Returns an async generator of integrations for the guild. Requires the ``MANAGE_GUILD`` permission. @@ -1289,9 +1324,11 @@ async def get_integrations(self) -> AsyncIterator[Integration]: AsyncGenerator[:class:`~pincer.objects.integration.Integration`, :data:`None`] An async generator of integrations. 
""" - data = await self._http.get(f"guilds/{self.id}/integrations") - for integration_data in data: - yield Integration.from_dict(integration_data) + + return APIDataGen( + Integration, + self._http.get(f"guilds/{self.id}/integrations") + ) async def delete_integration( self, integration: Integration, reason: Optional[str] = None @@ -1312,6 +1349,18 @@ async def delete_integration( headers={"X-Audit-Log-Reason": reason}, ) + async def delete_invite(self, code: str): + """|coro| + Deletes an invite. + Requires the ``MANAGE_GUILD`` intent. + + Parameters + ---------- + code : :class:`str` + The code of the invite to delete. + """ + await self._http.delete(f"guilds/{self.id}/invites/{code}") + async def get_widget_settings(self) -> GuildWidget: """|coro| Returns the guild widget settings. @@ -1544,7 +1593,7 @@ async def get_audit_log(self) -> AuditLog: await self._http.get(f"guilds/{self.id}/audit-logs") ) - async def get_emojis(self) -> AsyncGenerator[Emoji, None]: + def get_emojis(self) -> APIDataGen[Emoji]: """|coro| Returns an async generator of the emojis in the guild. @@ -1553,9 +1602,10 @@ async def get_emojis(self) -> AsyncGenerator[Emoji, None]: :class:`~pincer.objects.guild.emoji.Emoji` The emoji object. """ - data = await self._http.get(f"guilds/{self.id}/emojis") - for emoji_data in data: - yield Emoji.from_dict(emoji_data) + return APIDataGen( + Emoji, + self._http.get(f"guilds/{self.id}/emojis") + ) async def get_emoji(self, id: Snowflake) -> Emoji: """|coro| @@ -1669,7 +1719,7 @@ async def delete_emoji( headers={"X-Audit-Log-Reason": reason}, ) - async def get_templates(self) -> AsyncIterator[GuildTemplate]: + def get_templates(self) -> APIDataGen[GuildTemplate]: """|coro| Returns an async generator of the guild templates. @@ -1678,9 +1728,11 @@ async def get_templates(self) -> AsyncIterator[GuildTemplate]: AsyncGenerator[:class:`~pincer.objects.guild.template.GuildTemplate`, :data:`None`] The guild template object. 
""" - data = await self._http.get(f"guilds/{self.id}/templates") - for template_data in data: - yield GuildTemplate.from_dict(template_data) + + return APIDataGen( + GuildTemplate, + self._http.get(f"guilds/{self.id}/templates") + ) async def create_template( self, name: str, description: Optional[str] = None @@ -1780,7 +1832,7 @@ async def delete_template(self, template: GuildTemplate) -> GuildTemplate: ) return GuildTemplate.from_dict(data) - async def list_stickers(self) -> AsyncIterator[Sticker]: + def list_stickers(self) -> APIDataGen[Sticker]: """|coro| Yields sticker objects for the current guild. Includes ``user`` fields if the bot has the @@ -1792,8 +1844,10 @@ async def list_stickers(self) -> AsyncIterator[Sticker]: a sticker for the current guild """ - for sticker in await self._http.get(f"guild/{self.id}/stickers"): - yield Sticker.from_dict(sticker) + return APIDataGen( + Sticker, + self._http.get(f"guild/{self.id}/stickers") + ) async def get_sticker(self, _id: Snowflake) -> Sticker: """|coro| @@ -1874,7 +1928,7 @@ async def delete_sticker(self, _id: Snowflake): """ await self._http.delete(f"guilds/{self.id}/stickers/{_id}") - async def get_webhooks(self) -> AsyncGenerator[Webhook, None]: + def get_webhooks(self) -> APIDataGen[Webhook]: """|coro| Returns an async generator of the guild webhooks. @@ -1883,9 +1937,276 @@ async def get_webhooks(self) -> AsyncGenerator[Webhook, None]: AsyncGenerator[:class:`~pincer.objects.guild.webhook.Webhook`, None] The guild webhook object. """ - data = await self._http.get(f"guilds/{self.id}/webhooks") - for webhook_data in data: - yield Webhook.from_dict(webhook_data) + + return APIDataGen( + Webhook, + self._http.get(f"guilds/{self.id}/webhooks") + ) + + def get_scheduled_events( + self, with_user_count: bool = False + ) -> APIDataGen[ScheduledEvent]: + """ + Returns an async generator of the guild scheduled events. 
+ + Parameters + ---------- + with_user_count : :class:`bool` + Whether to include the user count in the scheduled event. + + Yields + ------ + :class:`~pincer.objects.guild.scheduled_event.ScheduledEvent` + The scheduled event object. + """ + + return APIDataGen( + ScheduledEvent, + self._http.get( + f"guilds/{self.id}/scheduled-events", + param={"with_user_count": with_user_count}, + ) + ) + + async def create_scheduled_event( + self, + name: str, + privacy_level: int, + entity_type: int, + scheduled_start_time: datetime, + scheduled_end_time: Optional[datetime] = None, + entity_metadata: Optional[str] = None, + channel_id: Optional[int] = None, + description: Optional[str] = None, + reason: Optional[str] = None, + ) -> ScheduledEvent: + """ + Create a new scheduled event for the guild. + + Parameters + ---------- + name : :class:`str` + The name of the scheduled event. + privacy_level : :class:`int` + The privacy level of the scheduled event. + entity_type : :class:`int` + The type of entity to be scheduled. + scheduled_start_time : :class:`datetime` + The scheduled start time of the event. + scheduled_end_time : Optional[:class:`datetime`] + The scheduled end time of the event. + entity_metadata : Optional[:class:`str`] + The metadata of the entity to be scheduled. + channel_id : Optional[:class:`int`] + The channel id of the channel to be scheduled. + description : Optional[:class:`str`] + The description of the scheduled event. + reason : Optional[:class:`str`] + The reason for creating the scheduled event. + + Raises + ------ + ValueError: + If an event is created in the past or if an event ends before it starts + + Returns + ------- + :class:`~pincer.objects.guild.scheduled_event.ScheduledEvent` + The newly created scheduled event. 
+ """ + if scheduled_start_time < datetime.now(): + raise ValueError("An event cannot be created in the past") + + if ( + scheduled_end_time + and scheduled_end_time < scheduled_start_time + ): + raise ValueError("An event cannot start before it ends") + + data = await self._http.post( + f"guilds/{self.id}/scheduled-events", + data={ + "name": name, + "privacy_level": privacy_level, + "entity_type": entity_type, + "scheduled_start_time": scheduled_start_time.isoformat(), + "scheduled_end_time": scheduled_end_time.isoformat() + if scheduled_end_time is not None + else None, + "entity_metadata": entity_metadata, + "channel_id": channel_id, + "description": description, + }, + headers={"X-Audit-Log-Reason": reason}, + ) + return ScheduledEvent.from_dict(data) + + async def get_scheduled_event( + self, _id: int, with_user_count: bool = False + ) -> ScheduledEvent: + """ + Get a scheduled event by id. + + Parameters + ---------- + _id : :class:`int` + The id of the scheduled event. + with_user_count : :class:`bool` + Whether to include the user count in the scheduled event. + + Returns + ------- + :class:`~pincer.objects.guild.scheduled_event.ScheduledEvent` + The scheduled event object. + """ + data = await self._http.get( + f"guilds/{self.id}/scheduled-events/{_id}", + params={"with_user_count": with_user_count}, + ) + return ScheduledEvent.from_dict(data) + + async def modify_scheduled_event( + self, + _id: int, + name: Optional[str] = None, + entity_type: Optional[int] = None, + privacy_level: Optional[int] = None, + scheduled_start_time: Optional[datetime] = None, + scheduled_end_time: Optional[datetime] = None, + entity_metadata: Optional[str] = None, + channel_id: Optional[int] = None, + description: Optional[str] = None, + status: Optional[int] = None, + reason: Optional[str] = None, + ) -> ScheduledEvent: + """ + Modify a scheduled event. + + Parameters + ---------- + _id : :class:`int` + The id of the scheduled event. 
+ name : Optional[:class:`str`] + The name of the scheduled event. + entity_type : Optional[:class:`int`] + The type of entity to be scheduled. + privacy_level : Optional[:class:`int`] + The privacy level of the scheduled event. + scheduled_start_time : Optional[:class:`datetime`] + The scheduled start time of the event. + scheduled_end_time : Optional[:class:`datetime`] + The scheduled end time of the event. + entity_metadata : Optional[:class:`str`] + The metadata of the entity to be scheduled. + channel_id : Optional[:class:`int`] + The channel id of the channel to be scheduled. + description : Optional[:class:`str`] + The description of the scheduled event. + status : Optional[:class:`int`] + The status of the scheduled event. + reason : Optional[:class:`str`] + The reason for modifying the scheduled event. + + Raises + ------ + :class:`ValueError` + If the scheduled event is in the past, + or if the scheduled end time is before the scheduled start time. + + Returns + ------- + :class:`~pincer.objects.guild.scheduled_event.ScheduledEvent` + The scheduled event object. 
+ """ + if scheduled_start_time: + if scheduled_start_time < datetime.now(): + raise ValueError("An event cannot be created in the past") + + if ( + scheduled_end_time + and scheduled_end_time < scheduled_start_time + ): + raise ValueError("An event cannot start before it ends") + + kwargs: Dict[str, str] = remove_none( + { + "name": name, + "privacy_level": privacy_level, + "entity_type": entity_type, + "scheduled_start_time": scheduled_start_time.isoformat() + if scheduled_start_time is not None + else None, + "scheduled_end_time": scheduled_end_time.isoformat() + if scheduled_end_time is not None + else None, + "entity_metadata": entity_metadata, + "channel_id": channel_id, + "description": description, + "status": status, + } + ) + + data = await self._http.patch( + f"guilds/{self.id}/scheduled-events/{_id}", + data=kwargs, + headers={"X-Audit-Log-Reason": reason}, + ) + return ScheduledEvent.from_dict(data) + + async def delete_scheduled_event(self, _id: int): + """ + Delete a scheduled event. + + Parameters + ---------- + _id : :class:`int` + The id of the scheduled event. + """ + await self._http.delete(f"guilds/{self.id}/scheduled-events/{_id}") + + def get_guild_scheduled_event_users( + self, + _id: int, + limit: int = 100, + with_member: bool = False, + before: Optional[int] = None, + after: Optional[int] = None, + ) -> APIDataGen[GuildScheduledEventUser]: + """ + Get the users of a scheduled event. + + Parameters + ---------- + _id : :class:`int` + The id of the scheduled event. + limit : :class:`int` + The number of users to retrieve. + with_member : :class:`bool` + Whether to include the member object in the scheduled event user. + before : Optional[:class:`int`] + consider only users before given user id + after : Optional[:class:`int`] + consider only users after given user id + + Yields + ------ + :class:`~pincer.objects.guild.scheduled_event.GuildScheduledEventUser` + The scheduled event user object. 
+ """ + params = remove_none({ + "limit": limit, + "with_member": with_member, + "before": before, + "after": after, + }) + + return APIDataGen( + GuildScheduledEventUser, + self._http.get( + f"guilds/{self.id}/scheduled-events/{_id}/users", + params=params, + ) + ) @classmethod def from_dict(cls, data) -> Guild: diff --git a/pincer/objects/guild/invite.py b/pincer/objects/guild/invite.py index d791b77e..be4f4ac3 100644 --- a/pincer/objects/guild/invite.py +++ b/pincer/objects/guild/invite.py @@ -32,6 +32,7 @@ class InviteTargetType(IntEnum): EMBEDDED_APPLICATION: An embedded application invite, e.g. poker-night etc. """ + STREAM = 1 EMBEDDED_APPLICATION = 2 @@ -51,6 +52,7 @@ class InviteStageInstance(APIObject): topic: :class:`str` the topic of the Stage instance (1-120 characters) """ + members: List[GuildMember] participant_count: int speaker_count: int @@ -101,6 +103,7 @@ class Invite(APIObject): created_at: APINullable[:class:`~pincer.utils.timestamp.Timestamp`] When this invite was created """ + # noqa: E501 channel: Channel @@ -130,3 +133,17 @@ def __str__(self) -> str: @property def link(self): return f"https://discord.gg/{self.code}" + + async def delete(self): + """Delete this invite. 
+ + Raises + ------ + Forbidden + You do not have permission to delete this invite + NotFound + This invite does not exist + HTTPException + Deleting the invite failed + """ + await self._http.delete(f"guilds/{self.guild.id}/invites/{self.code}") diff --git a/pincer/objects/guild/overwrite.py b/pincer/objects/guild/overwrite.py index 59d5ddaf..b8b4cf60 100644 --- a/pincer/objects/guild/overwrite.py +++ b/pincer/objects/guild/overwrite.py @@ -6,6 +6,7 @@ from dataclasses import dataclass from typing import TYPE_CHECKING +from .permissions import Permissions from ...utils import APIObject if TYPE_CHECKING: @@ -31,3 +32,8 @@ class Overwrite(APIObject): type: int allow: str deny: str + + @property + def permissions(self) -> Permissions: + """Returns the permissions for this overwrite""" + return Permissions.from_ints(int(self.allow), int(self.deny)) \ No newline at end of file diff --git a/pincer/objects/guild/permissions.py b/pincer/objects/guild/permissions.py new file mode 100644 index 00000000..d8185ab2 --- /dev/null +++ b/pincer/objects/guild/permissions.py @@ -0,0 +1,262 @@ +# Copyright Pincer 2021-Present +# Full MIT License can be found in `LICENSE` at the project root. + +from __future__ import annotations + +from dataclasses import dataclass +from enum import IntFlag +from typing import Tuple, Optional + + +class PermissionEnum(IntFlag): + """ + Represents the permissions for a guild. 
+ """ + CREATE_INSTANT_INVITE = 1 << 0 + KICK_MEMBERS = 1 << 1 + BAN_MEMBERS = 1 << 2 + ADMINISTRATOR = 1 << 3 + MANAGE_CHANNELS = 1 << 4 + MANAGE_GUIlD = 1 << 5 + ADD_REACTIONS = 1 << 6 + VIEW_AUDIT_LOG = 1 << 7 + PRIORITY_SPEAKER = 1 << 8 + STREAM = 1 << 9 + VIEW_CHANNEL = 1 << 10 + SEND_MESSAGES = 1 << 11 + SEND_TTS_MESSAGES = 1 << 12 + MANAGE_MESSAGES = 1 << 13 + EMBED_LINKS = 1 << 14 + ATTACH_FILES = 1 << 15 + READ_MESSAGE_HISTORY = 1 << 16 + MENTION_EVERYONE = 1 << 17 + USE_EXTERNAL_EMOJIS = 1 << 18 + VIEW_GUILD_INSIGHTS = 1 << 19 + CONNECT = 1 << 20 + SPEAK = 1 << 21 + MUTE_MEMBERS = 1 << 22 + DEAFEN_MEMBERS = 1 << 23 + MOVE_MEMBERS = 1 << 24 + USE_VAD = 1 << 25 + CHANGE_NICKNAME = 1 << 26 + MANAGE_NICKNAMES = 1 << 27 + MANAGE_ROLES = 1 << 28 + MANAGE_WEBHOOKS = 1 << 29 + MANAGE_EMOJIS_AND_STICKERS = 1 << 30 + USE_APPLICATION_COMMANDS = 1 << 31 + REQUEST_TO_SPEAK = 1 << 32 + MANAGE_EVENTS = 1 << 33 + MANAGE_THREADS = 1 << 34 + CREATE_PUBLIC_THREADS = 1 << 35 + CREATE_PRIVATE_THREADS = 1 << 36 + USE_EXTERNAL_STICKERS = 1 << 37 + SEND_MESSAGES_IN_THREADS = 1 << 38 + START_EMBEDDED_ACTIVITIES = 1 << 39 + MODERATE_MEMBERS = 1 << 40 + + +@dataclass +class Permissions: + """ + Allows for easier access to the permissions + + Parameters + ---------- + create_instant_invite: :class:Optional[:class:`bool`] + Allows creation of instant invites + kick_members: :class:Optional[:class:`bool`] + Allows kicking members + ban_members: :class:Optional[:class:`bool`] + Allows banning members + administrator: :class:Optional[:class:`bool`] + Allows all permissions and bypasses channel permission overwrites + manage_channels: :class:Optional[:class:`bool`] + Allows management and editing of channels + manage_guild: :class:Optional[:class:`bool`] + Allows management and editing of the guild + add_reactions: :class:Optional[:class:`bool`] + Allows for the addition of reactions to messages + view_audit_log: :class:Optional[:class:`bool`] + Allows for viewing of audit logs + 
priority_speaker: :class:Optional[:class:`bool`] + Allows for using priority speaker in a voice channel + stream: :class:Optional[:class:`bool`] + Allows the user to go live + view_channel: :class:Optional[:class:`bool`] + Allows guild members to view a channel, which includes reading messages in text channels + send_messages: :class:Optional[:class:`bool`] + Allows for sending messages in a channel (does not allow sending messages in threads) + send_tts_messages: :class:Optional[:class:`bool`] + Allows for sending of tts messages + manage_messages: :class:Optional[:class:`bool`] + Allows for deletion of other users messages + embed_links: :class:Optional[:class:`bool`] + Links sent by users with this permission will be auto-embedded + attach_files: :class:Optional[:class:`bool`] + Allows for uploading images and files + read_message_history: :class:Optional[:class:`bool`] + Allows for reading of message history + mention_everyone: :class:Optional[:class:`bool`] + Allows for using the @everyone tag to notify all users in a channel, and the @here tag to notify all online users in a channel + use_external_emojis: :class:Optional[:class:`bool`] + Allows the usage of custom emojis from other servers + view_guild_insights: :class:Optional[:class:`bool`] + Allows for viewing of guild insights + connect: :class:Optional[:class:`bool`] + Allows for joining of a voice channel + speak: :class:Optional[:class:`bool`] + Allows for speaking in a voice channel + mute_members: :class:Optional[:class:`bool`] + Allows for muting members in a voice channel + deafen_members: :class:Optional[:class:`bool`] + Allows for deafening of members in a voice channel + move_members: :class:Optional[:class:`bool`] + Allows for moving of members between voice channels + use_vad: :class:Optional[:class:`bool`] + Allows for using voice activity detection in a voice channel + change_nickname: :class:Optional[:class:`bool`] + Allows for modification of own nickname + manage_nicknames: 
:class:Optional[:class:`bool`] + Allows for modification of other users nicknames + manage_roles: :class:Optional[:class:`bool`] + Allows for management and editing of roles + manage_webhooks: :class:Optional[:class:`bool`] + Allows for management and editing of webhooks + manage_emojis_and_stickers: :class:Optional[:class:`bool`] + Allows for management and editing of emojis and stickers + use_application_commands: :class:Optional[:class:`bool`] + Allows for using application-specific commands + request_to_speak: :class:Optional[:class:`bool`] + Allows for requesting to speak in a voice channel + manage_events: :class:Optional[:class:`bool`] + Allows for management and editing of events + manage_threads: :class:Optional[:class:`bool`] + Allows for management and editing of threads + create_public_threads: :class:Optional[:class:`bool`] + Allows for the creation of public threads + create_private_threads: :class:Optional[:class:`bool`] + Allows for the creation of private threads + use_external_stickers: :class:Optional[:class:`bool`] + Allows for the usage of stickers from other servers + send_messages_in_threads: :class:Optional[:class:`bool`] + Allows for sending messages in threads + start_embedded_activities: :class:Optional[:class:`bool`] + Allows for starting of embedded activities + moderate_members: :class:Optional[:class:`bool`] + Allows for moderation of members in a guild + """ + + create_instant_invite: Optional[bool] = None + kick_members: Optional[bool] = None + ban_members: Optional[bool] = None + administrator: Optional[bool] = None + manage_channels: Optional[bool] = None + manage_guild: Optional[bool] = None + add_reactions: Optional[bool] = None + view_audit_log: Optional[bool] = None + priority_speaker: Optional[bool] = None + stream: Optional[bool] = None + view_channel: Optional[bool] = None + send_messages: Optional[bool] = None + send_tts_messages: Optional[bool] = None + manage_messages: Optional[bool] = None + embed_links: Optional[bool] 
= None + attach_files: Optional[bool] = None + read_message_history: Optional[bool] = None + mention_everyone: Optional[bool] = None + use_external_emojis: Optional[bool] = None + view_guild_insights: Optional[bool] = None + connect: Optional[bool] = None + speak: Optional[bool] = None + mute_members: Optional[bool] = None + deafen_members: Optional[bool] = None + move_members: Optional[bool] = None + use_vad: Optional[bool] = None + change_nickname: Optional[bool] = None + manage_nicknames: Optional[bool] = None + manage_roles: Optional[bool] = None + manage_webhooks: Optional[bool] = None + manage_emojis_and_stickers: Optional[bool] = None + use_application_commands: Optional[bool] = None + request_to_speak: Optional[bool] = None + manage_events: Optional[bool] = None + manage_threads: Optional[bool] = None + create_public_threads: Optional[bool] = None + create_private_threads: Optional[bool] = None + use_external_stickers: Optional[bool] = None + send_messages_in_threads: Optional[bool] = None + start_embedded_activities: Optional[bool] = None + moderate_members: Optional[bool] = None + + def __setattr__(self, name: str, value: Optional[bool]) -> None: + if not isinstance(value, bool) and value is not None: + raise ValueError(f"Permission {name!r} must be a boolean or None") + return super().__setattr__(name, value) + + @classmethod + def from_ints(cls, allow: int, deny: int) -> Permissions: + """ + Create a Permission object from an integer representation of the permissions (deny and allow) + + Parameters + ---------- + allow: :class:`int` + The integer representation of the permissions that are allowed + deny: :class:`int` + The integer representation of the permissions that are denied + """ + clsobj = cls() + + for enum in PermissionEnum: + value = None + if enum.value & allow: + value = True + elif enum.value & deny: + value = False + + setattr(clsobj, enum.name.lower(), value) + + return clsobj + + def to_ints(self) -> Tuple[int]: + """ + Convert the 
Permission object to an integer representation of the permissions (deny and allow) + + Returns + ------- + :class:`Tuple[:class:`int`]` + The integer representation of the permissions that are allowed and denied + """ + allow = 0 + deny = 0 + for enum in PermissionEnum: + if getattr(self, enum.name.lower()): + allow |= enum.value + elif getattr(self, enum.name.lower()) is False: + deny |= enum.value + + return allow, deny + + @property + def allow(self) -> int: + """ + Returns the integer representation of the permissions that are allowed + """ + allow = 0 + for enum in PermissionEnum: + if getattr(self, enum.name.lower()): + allow |= enum.value + + return allow + + @property + def deny(self) -> int: + """ + Returns the integer representation of the permissions that are denied + """ + deny = 0 + for enum in PermissionEnum: + if getattr(self, enum.name.lower()) is False: + deny |= enum.value + + return deny diff --git a/pincer/objects/guild/scheduled_events.py b/pincer/objects/guild/scheduled_events.py index 4f9f5eb2..d4172875 100644 --- a/pincer/objects/guild/scheduled_events.py +++ b/pincer/objects/guild/scheduled_events.py @@ -11,6 +11,7 @@ if TYPE_CHECKING: from ..guild.stage import PrivacyLevel + from ..guild.member import GuildMember from ..user.user import User from ...utils.snowflake import Snowflake from ...utils.timestamp import Timestamp @@ -77,22 +78,22 @@ class ScheduledEvent(APIObject): The status of the scheduled event. 
entity_type: :class:`~pincer.guild.schedule_events.GuildScheduledEventEntityType` The type of the scheduled event - channel_id: :class:`int` + channel_id: APINullable[:class:`int`] The channel id in which the scheduled event will be hosted, or null if scheduled entity type is EXTERNAL - creator_id: :class:`int` + creator_id: APINullable[:class:`int`] The user id of the creator of the scheduled event scheduled_end_time: str The time the scheduled event will end, required if entity_type is EXTERNAL - description: :class:`str` + description: APINullable[:class:`str`] The description of the scheduled event (0-1000 characters) - entity_id: :class:`int` + entity_id: APINullable[:class:`int`] The id of an entity associated with a guild scheduled event - entity_metadata: :class:`str` + entity_metadata: APINullable[:class:`str`] Additional metadata for the guild scheduled event - creator: :class:`~pincer.objects.user.user.User` + creator: APINullable[:class:`~pincer.objects.user.user.User`] The user who created the scheduled event - user_count: :class:`int` + user_count: APINullable[:class:`int`] The number of users who have joined the scheduled event """ id: Snowflake @@ -112,3 +113,22 @@ class ScheduledEvent(APIObject): entity_metadata: APINullable[str] = MISSING creator: APINullable[User] = MISSING user_count: APINullable[int] = MISSING + + +@dataclass +class GuildScheduledEventUser(APIObject): + """ + Represents a user who has joined a scheduled event. 
+ + Attributes + ---------- + guild_scheduled_event_id: :class:`int` + the scheduled event id which the user subscribed to + user : :class:`~pincer.objects.user.user.User` + user which subscribed to an event + member : APINullable[:class:`~pincer.objects.guild.member.GuildMember`] + guild member data for this user for the guild which this event belongs to, if any + """ + guild_scheduled_event_id: Snowflake + user: User + member: APINullable[GuildMember] = MISSING diff --git a/pincer/objects/guild/stage.py b/pincer/objects/guild/stage.py index c6db89a5..9ca607d7 100644 --- a/pincer/objects/guild/stage.py +++ b/pincer/objects/guild/stage.py @@ -5,11 +5,12 @@ from dataclasses import dataclass from enum import IntEnum -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional from ...utils.api_object import APIObject, ChannelProperty, GuildProperty if TYPE_CHECKING: + from ...client import Client from ...utils.snowflake import Snowflake @@ -23,6 +24,7 @@ class PrivacyLevel(IntEnum): GUILD_ONLY: The stage of for guild members only. """ + PUBLIC = 1 GUILD_ONLY = 2 @@ -46,9 +48,36 @@ class StageInstance(APIObject, ChannelProperty, GuildProperty): discoverable: :class:`bool` Is Stage Discovery enabled """ + id: Snowflake guild_id: Snowflake channel_id: Snowflake topic: str privacy_level: PrivacyLevel discoverable: bool + + @classmethod + async def from_id(cls, client: Client, _id: int) -> StageInstance: + return client.http.get(f"stage-instance/{_id}") + + async def modify( + self, + topic: Optional[str] = None, + privacy_level: Optional[PrivacyLevel] = None, + reason: Optional[str] = None, + ): + """|coro| + Updates fields of an existing Stage instance. + Requires the user to be a moderator of the Stage channel. 
+ + Parameters + ---------- + topic : Optional[:class:`str`] + The topic of the Stage instance (1-120 characters) + privacy_level : Optional[:class:`~pincer.objects.guild.stage.PrivacyLevel`] + The privacy level of the Stage instance + reason : Optional[:class:`str`] + The reason for the modification + """ + + await self._client.modify_stage(self.id, topic, privacy_level, reason) diff --git a/pincer/objects/message/context.py b/pincer/objects/message/context.py index 9288230d..985840dd 100644 --- a/pincer/objects/message/context.py +++ b/pincer/objects/message/context.py @@ -29,7 +29,7 @@ class MessageContext: author: Union[:class:`~pincer.objects.guild.member.GuildMember`, :class:`~pincer.objects.user.user.User`] The user whom invoked the interaction. - command: :class:`~pincer.objects.app.command.ClientCommandStructure` + command: :class:`~pincer.objects.app.command.InteractableStructure` The local command object for the command to whom this context belongs. diff --git a/pincer/objects/message/sticker.py b/pincer/objects/message/sticker.py index a0a907c4..bce85b03 100644 --- a/pincer/objects/message/sticker.py +++ b/pincer/objects/message/sticker.py @@ -84,13 +84,13 @@ class Sticker(APIObject): the user that uploaded the guild sticker """ - description: Optional[str] format_type: StickerFormatType id: Snowflake name: str tags: str type: StickerType + description: Optional[str] = None available: APINullable[bool] = MISSING guild_id: APINullable[Snowflake] = MISSING pack_id: APINullable[Snowflake] = MISSING diff --git a/pincer/objects/message/user_message.py b/pincer/objects/message/user_message.py index cb0bcfcc..cde75fc7 100644 --- a/pincer/objects/message/user_message.py +++ b/pincer/objects/message/user_message.py @@ -5,7 +5,7 @@ from collections import defaultdict from dataclasses import dataclass -from enum import Enum, IntEnum +from enum import Enum, IntEnum, IntFlag from typing import TYPE_CHECKING, DefaultDict from .attachment import Attachment @@ -20,6 
+20,7 @@ from ..guild.role import Role from ..user.user import User from ..._config import GatewayConfig +from ...utils.api_data import APIDataGen from ...utils.api_object import APIObject, GuildProperty, ChannelProperty from ...utils.snowflake import Snowflake from ...utils.types import MISSING, JSONSerializable @@ -113,7 +114,7 @@ class MessageActivityType(IntEnum): JOIN_REQUEST = 5 -class MessageFlags(IntEnum): +class MessageFlags(IntFlag): """Special message properties. Attributes @@ -198,7 +199,7 @@ class MessageType(IntEnum): APPLICATION_COMMAND: Slash command is used and responded to. THREAD_STARTER_MESSAGE: - The initial message in a thread when its created off a message. + The initial message in a thread when it's created off a message. GUILD_INVITE_REMINDER: ?? """ @@ -387,6 +388,22 @@ async def from_id( msg = await client.http.get(f"channels/{channel_id}/messages/{_id}") return cls.from_dict(msg) + async def crosspost(self) -> UserMessage: + """|coro| + Crosspost a message in a News Channel to following channels. + + This endpoint requires the ``SEND_MESSAGES`` permission, + if the current user sent the message, or additionally the + ``MANAGE_MESSAGES`` permission, for all other messages, + to be present for the current user. + + Returns + ------- + :class:`~pincer.objects.message.UserMessage` + The crossposted message + """ + return await self._client.crosspost_message(self.channel_id, self.id) + def __str__(self): return self.content @@ -455,9 +472,9 @@ async def remove_user_reaction(self, emoji: str, user_id: Snowflake): f"/{user_id}" ) - async def get_reactions( + def get_reactions( self, emoji: str, after: Snowflake = 0, limit=25 - ) -> Generator[User, None, None]: + ) -> APIDataGen[User]: # TODO: HTTP Client will need to refactored to allow parameters using aiohttp's system. """|coro| @@ -474,12 +491,13 @@ async def get_reactions( Max number of users to return (1-100). 
|default| ``25`` """ - - for user in await self._http.get( - f"/channels/{self.channel_id}/messages/{self.id}/reactions/{emoji}", - params={"after": after, "limit": limit}, - ): - yield User.from_dict(user) + return APIDataGen( + User, + self._http.get( + f"/channels/{self.channel_id}/messages/{self.id}/reactions/{emoji}", + params={"after": after, "limit": limit}, + ) + ) async def remove_all_reactions(self): """|coro| @@ -508,7 +526,7 @@ async def remove_emoji(self, emoji): f"/channels/{self.channel_id}/messages/{self.id}/reactions/{emoji}" ) - # TODO: Implement file (https://discord.com/developers/docs/resources/channel#edit-message) + # TODO: Implement file (https://discord.dev/resources/channel#edit-message) async def edit( self, content: str = None, diff --git a/pincer/utils/__init__.py b/pincer/utils/__init__.py index bc9e1c79..d1f4e2e6 100644 --- a/pincer/utils/__init__.py +++ b/pincer/utils/__init__.py @@ -9,6 +9,7 @@ from .extraction import get_index from .insertion import should_pass_cls, should_pass_ctx from .replace import replace +from .shards import calculate_shard_id from .signature import get_params, get_signature_and_params from .snowflake import Snowflake from .tasks import Task, TaskScheduler @@ -23,10 +24,10 @@ ) __all__ = ( - "", "APINullable", "APIObject", "ChannelProperty", - "CheckFunction", "Color", "Coro", "EventMgr", "GuildProperty", "MISSING", - "MissingType", "Snowflake", "Task", "TaskScheduler", "Timestamp", "chdir", - "choice_value_types", "get_index", "get_params", + "APINullable", "APIObject", "ChannelProperty", "CheckFunction", + "Color", "Coro", "EventMgr", "GuildProperty", "MISSING", "MissingType", + "Snowflake", "Task", "TaskScheduler", "Timestamp", "calculate_shard_id", + "chdir", "choice_value_types", "get_index", "get_params", "get_signature_and_params", "remove_none", "replace", "should_pass_cls", "should_pass_ctx" ) diff --git a/pincer/utils/api_data.py b/pincer/utils/api_data.py new file mode 100644 index 
00000000..0262d92a --- /dev/null +++ b/pincer/utils/api_data.py @@ -0,0 +1,36 @@ +# Copyright Pincer 2021-Present +# Full MIT License can be found in `LICENSE` at the project root. + +from __future__ import annotations + +from typing import Generic, TypeVar, TYPE_CHECKING + +from pincer.utils import APIObject + +if TYPE_CHECKING: + from typing import Any, Coroutine, List, Type, Generator, AsyncIterator + +T = TypeVar('T') + + +class APIDataGen(Generic[T]): + + def __init__( + self, + factory: Type[T], + request_func: Coroutine[Any, None, Any] + ): + + self.fac = factory if isinstance(factory, APIObject) else factory.from_dict + self.request_func = request_func + + async def __async(self) -> List[T]: + data = await self.request_func + return [self.fac(i) for i in data] + + def __await__(self) -> Generator[Any, None, Any]: + return self.__async().__await__() + + async def __aiter__(self) -> AsyncIterator[List[T]]: + for item in await self: + yield self.fac(item) diff --git a/pincer/utils/api_object.py b/pincer/utils/api_object.py index 2001a5ec..2b22f117 100644 --- a/pincer/utils/api_object.py +++ b/pincer/utils/api_object.py @@ -5,7 +5,7 @@ import copy import logging -from dataclasses import dataclass, fields, _is_dataclass_instance +from dataclasses import fields, _is_dataclass_instance from enum import Enum, EnumMeta from inspect import getfullargspec from itertools import chain @@ -17,7 +17,6 @@ TypeVar, Any, TYPE_CHECKING, - List, get_type_hints, get_origin, get_args, @@ -100,7 +99,7 @@ def _http(self) -> HTTPClient: return self._client.http @classmethod - def link(cls, client: Client): + def bind_client(cls, client: Client): """ Links the object to the client. 
@@ -186,11 +185,9 @@ def __attr_convert(self, attr_value: Dict, attr_type: T) -> T: def __post_init__(self): TypeCache() - attributes = chain( - *( - get_type_hints(cls, globalns=TypeCache.cache).items() - for cls in chain(self.__class__.__bases__, (self,)) - ) + attributes = chain.from_iterable( + get_type_hints(cls, globalns=TypeCache.cache).items() + for cls in chain(self.__class__.__bases__, (self,)) ) for attr, attr_type in attributes: diff --git a/pincer/utils/event_mgr.py b/pincer/utils/event_mgr.py index 29802b72..acb48692 100644 --- a/pincer/utils/event_mgr.py +++ b/pincer/utils/event_mgr.py @@ -4,19 +4,19 @@ from __future__ import annotations from abc import ABC, abstractmethod -from asyncio import Event, wait_for as _wait_for, get_running_loop, TimeoutError +from asyncio import Event, wait_for as _wait_for, TimeoutError from collections import deque from typing import TYPE_CHECKING from ..exceptions import TimeoutError as PincerTimeoutError if TYPE_CHECKING: + from asyncio import AbstractEventLoop from typing import Any, List, Union, Optional from .types import CheckFunction class _Processable(ABC): - @abstractmethod def process(self, event_name: str, event_value: Any): """ @@ -89,11 +89,7 @@ class _Event(_Processable): returned later. """ - def __init__( - self, - event_name: str, - check: CheckFunction - ): + def __init__(self, event_name: str, check: CheckFunction): self.event_name = event_name self.check = check self.event = Event() @@ -104,14 +100,13 @@ async def wait(self): """Waits until ``self.event`` is set.""" await self.event.wait() - def process(self, event_name: str, event_value: Any) -> bool: + def process(self, event_name: str, event_value: Any): # TODO: fix docs """ Parameters ---------- event_name - args Returns ------- @@ -162,8 +157,6 @@ def process(self, event_name: str, event_value: Any): Parameters ---------- event_name - args - Returns ------- @@ -197,8 +190,9 @@ class EventMgr: The List of events that need to be processed. 
""" - def __init__(self): + def __init__(self, loop: AbstractEventLoop): self.event_list: List[_Processable] = [] + self.loop = loop def process_events(self, event_name, event_value): """ @@ -213,10 +207,7 @@ def process_events(self, event_name, event_value): event.process(event_name, event_value) async def wait_for( - self, - event_name: str, - check: CheckFunction, - timeout: Optional[float] + self, event_name: str, check: CheckFunction, timeout: Optional[float] ) -> Any: """ Parameters @@ -277,17 +268,13 @@ async def loop_for( loop_mgr = _LoopMgr(event_name, check) self.event_list.append(loop_mgr) - loop = get_running_loop() - while True: - start_time = loop.time() + start_time = self.loop.time() try: yield await _wait_for( loop_mgr.get_next(), - timeout=_lowest_value( - loop_timeout, iteration_timeout - ) + timeout=_lowest_value(loop_timeout, iteration_timeout), ) except TimeoutError: @@ -305,7 +292,7 @@ async def loop_for( # `not` can't be used here because there is a check for # `loop_timeout == 0` if loop_timeout is not None: - loop_timeout -= loop.time() - start_time + loop_timeout -= self.loop.time() - start_time # loop_timeout can be below 0 if the user's code in the for loop # takes longer than the time left in loop_timeout diff --git a/pincer/utils/replace.py b/pincer/utils/replace.py index dea1f4ce..d51d6c8d 100644 --- a/pincer/utils/replace.py +++ b/pincer/utils/replace.py @@ -1,10 +1,13 @@ +# Copyright Pincer 2021-Present +# Full MIT License can be found in `LICENSE` at the project root. 
+ from typing import Any, Callable, Iterable, List, T def replace( - func: Callable[[Any], bool], iter: Iterable[T], new_item: T + func: Callable[[Any], bool], iter_: Iterable[T], new_item: T ) -> List[T]: return [ item if func(item) else new_item - for item in iter + for item in iter_ ] diff --git a/pincer/utils/shards.py b/pincer/utils/shards.py new file mode 100644 index 00000000..cc4c1a40 --- /dev/null +++ b/pincer/utils/shards.py @@ -0,0 +1,23 @@ +# Copyright Pincer 2021-Present +# Full MIT License can be found in `LICENSE` at the project root. + +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from typing import Union + from .snowflake import Snowflake + + +def calculate_shard_id(guild_id: Union[Snowflake, int], num_shards: int) -> int: + """Calculates the shard receiving the events for a specified guild + + Parameters + ---------- + guild_id : Optional[~pincer.utils.snowflake.Snowflake] + The guild_id of the shard to look for + num_shards : Optional[int] + The number of shards. + """ + return (guild_id >> 22) % num_shards diff --git a/setup.cfg b/setup.cfg index 2af3689e..ca6ae505 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,6 @@ [metadata] name = pincer -version = 0.15.3 +version = 0.16.0 description = Discord API wrapper rebuild from scratch. 
long_description = file: docs/PYPI.md long_description_content_type = text/markdown @@ -33,33 +33,34 @@ classifiers = include_package_data = True packages = pincer - pincer.middleware - pincer.core - pincer.utils - pincer.commands - pincer.commands.components pincer.objects - pincer.objects.voice + pincer.objects.guild pincer.objects.user - pincer.objects.events pincer.objects.app + pincer.objects.voice + pincer.objects.events pincer.objects.message - pincer.objects.guild + pincer.core + pincer.utils + pincer.commands + pincer.commands.components + pincer.middleware install_requires = aiohttp~=3.8 python_requires = >=3.8 [options.extras_require] testing = - coverage==6.2 + coverage==6.3.2 flake8==4.0.1 tox==3.24.4 - pre-commit==2.16.0 + pre-commit==2.17.0 pytest==6.2.5 pytest-cov==3.0.0 mypy==0.910 img = Pillow==8.4.0 + types-Pillow==9.0.6 speed = orjson>=3.5.4 Brotli>=1.0.9 diff --git a/tests/objects/guild/test_permission.py b/tests/objects/guild/test_permission.py new file mode 100644 index 00000000..d9c5c914 --- /dev/null +++ b/tests/objects/guild/test_permission.py @@ -0,0 +1,93 @@ +# Copyright Pincer 2021-Present +# Full MIT License can be found in `LICENSE` at the project root. 
+ +from pincer.objects.guild.permissions import Permissions, PermissionEnum + + +class TestPermission: + @staticmethod + def test_valid_permissions(): + valid_perms = ( + "create_instant_invite", + "kick_members", + "ban_members", + "administrator", + "manage_channels", + "manage_guild", + "add_reactions", + "view_audit_log", + "priority_speaker", + "stream", + "view_channel", + "send_messages", + "send_tts_messages", + "manage_messages", + "embed_links", + "attach_files", + "read_message_history", + "mention_everyone", + "use_external_emojis", + "view_guild_insights", + "connect", + "speak", + "mute_members", + "deafen_members", + "move_members", + "use_vad", + "change_nickname", + "manage_nicknames", + "manage_roles", + "manage_webhooks", + "manage_emojis_and_stickers", + "use_application_commands", + "request_to_speak", + "manage_events", + "manage_threads", + "create_public_threads", + "create_private_threads", + "use_external_stickers", + "send_messages_in_threads", + "start_embedded_activities", + "moderate_members", + ) + + for perm in valid_perms: + assert hasattr(Permissions(), perm) + + @staticmethod + def test_from_int(): + assert Permissions.from_ints(1025, 268435472) == Permissions( + view_channel=True, + manage_channels=False, + create_instant_invite=True, + manage_roles=False, + ) + + assert Permissions.from_ints(0, 0) == Permissions() + + @staticmethod + def test_to_int(): + allow, deny = Permissions.to_ints(Permissions()) + assert allow == 0 + assert deny == 0 + + permission = Permissions() + for enum in PermissionEnum: + if getattr(permission, enum.name.lower()): + allow |= enum.value + elif getattr(permission, enum.name.lower()) is False: + deny |= enum.value + + assert Permissions.to_ints(Permissions()) == (0, 0) + + @staticmethod + def test_allow(): + permission = Permissions( + view_channel=True, + manage_channels=False, + create_instant_invite=True, + manage_roles=False, + ) + + assert permission.allow == 1025 + assert permission.deny == 
268435472