# Source code for twitchio.ext.commands.core

"""
The MIT License (MIT)

Copyright (c) 2017-present TwitchIO

Permission is hereby granted, free of charge, to any person obtaining a
copy of this software and associated documentation files (the "Software"),
to deal in the Software without restriction, including without limitation
the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the
Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""

from __future__ import annotations
import inspect

import itertools
import copy
import types
from typing import Any, Union, Optional, Callable, Awaitable, Tuple, TYPE_CHECKING, List, Type, Set, TypeVar
from typing_extensions import Literal

from twitchio.abcs import Messageable
from .cooldowns import *
from .errors import *
from . import builtin_converter

if TYPE_CHECKING:
    import sys

    from twitchio import Message, Chatter, PartialChatter, Channel, User, PartialUser
    from . import Cog, Bot
    from .stringparser import StringParser

    if sys.version_info >= (3, 10):
        UnionT = Union[types.UnionType, Union]
    else:
        UnionT = Union


# Names exported by ``from <this module> import *``. ``group`` is listed alongside
# ``command`` because both decorators are defined at module level below.
__all__ = ("Command", "command", "Group", "Context", "cooldown", "group")


class EmptyArgumentSentinel:
    """Marker type whose instance stands for "no value was produced".

    Optional converters in this module return the shared ``EMPTY`` instance
    instead of raising, and the argument parser detects it with an identity
    (``is``) check.
    """

    def __eq__(self, __value: object) -> bool:
        # Equality never succeeds -- not even against itself -- so only ``is`` can match it.
        return False

    def __repr__(self) -> str:
        return "<EMPTY>"


# The single shared sentinel instance used by the converters and the parser.
EMPTY = EmptyArgumentSentinel()


def _boolconverter(_, param: str):
    param = param.lower()
    if param in {"yes", "y", "1", "true", "on"}:
        return True
    elif param in {"no", "n", "0", "false", "off"}:
        return False
    raise BadArgument(f"Expected a boolean value, got {param}")


class Command:
    """A class for implementing bot commands.

    Parameters
    ------------
    name: :class:`str`
        The name of the command.
    func: :class:`Callable`
        The coroutine that executes when the command is invoked.

    Attributes
    ------------
    name: :class:`str`
        The name of the command.
    cog: :class:`~twitchio.ext.commands.Cog`
        The cog this command belongs to.
    aliases: Optional[Union[:class:`list`, :class:`tuple`]]
        Aliases that can be used to also invoke the command.
    """

    def __init__(self, name: str, func: Callable, **attrs) -> None:
        if not inspect.iscoroutinefunction(func):
            raise TypeError("Command callback must be a coroutine.")
        self._callback = func
        self._checks = []
        self._cooldowns = []
        self._name = name

        self._instance = None
        self.cog = None
        self.parent: Optional[Group] = attrs.get("parent")

        try:
            self._checks.extend(func.__checks__)  # type: ignore
        except AttributeError:
            pass
        try:
            self._cooldowns.extend(func.__cooldowns__)  # type: ignore
        except AttributeError:
            pass
        self.aliases = attrs.get("aliases", None)
        sig = inspect.signature(func)
        self.params = sig.parameters.copy()  # type: ignore

        self.event_error = None
        self._before_invoke = None
        self._after_invoke = None
        self.no_global_checks = attrs.get("no_global_checks", False)

        for key, value in self.params.items():
            if isinstance(value.annotation, str):
                self.params[key] = value.replace(annotation=eval(value.annotation, func.__globals__))  # type: ignore

    @property
    def name(self) -> str:
        """The name that was given to the command when it was created."""
        return self._name

    @property
    def full_name(self) -> str:
        if not self.parent:
            return self._name
        return f"{self.parent.full_name} {self._name}"

    def _is_optional_argument(self, converter: Any):
        return (getattr(converter, "__origin__", None) is Union or isinstance(converter, types.UnionType)) and type(
            None
        ) in converter.__args__

    def resolve_union_callback(self, name: str, converter: UnionT) -> Callable[[Context, str], Any]:
        """Build a converter that tries each member of a union annotation in order.

        The returned coroutine function resolves a converter for every member of
        ``converter.__args__`` in turn and returns the first result obtained
        without an exception.

        Raises (from the returned callback)
        --------
        UnionArgumentParsingFailed
            Every member failed, or the member that succeeded produced ``EMPTY``.
        """
        members = converter.__args__  # type: ignore # pyright doesnt like this

        async def _resolve(context: Context, arg: str) -> Any:
            for member in members:
                # Resolved outside the ``try`` on purpose: a failure to resolve is not swallowed.
                member_converter = self._resolve_converter(name, member, context)

                try:
                    outcome: Any = member_converter(context, arg)
                    if inspect.iscoroutine(outcome):
                        outcome = await outcome
                except Exception:
                    # This member could not handle the text; move on to the next one.
                    continue

                if outcome is not EMPTY:
                    return outcome
                # A successful call that yields EMPTY ends the search and counts as a failure.
                break

            raise UnionArgumentParsingFailed(name, members)

        return _resolve

    def resolve_optional_callback(self, name: str, converter: Any, context: Context) -> Callable[[Context, str], Any]:
        underlying = self._resolve_converter(name, converter.__args__[0], context)

        async def _resolve(context: Context, arg: str) -> Any:
            try:
                t: Any = underlying(context, arg)
                if inspect.iscoroutine(t):
                    t = await t

            except Exception:
                return EMPTY  # instruct the parser to roll back and ignore this argument

            return t

        return _resolve

    def _resolve_converter(
        self, name: str, converter: Union[Callable, Awaitable, type], ctx: Context
    ) -> Callable[..., Any]:
        if (
            isinstance(converter, type)
            and converter.__module__.startswith("twitchio")
            and converter in builtin_converter._mapping
        ):
            return self._convert_builtin_type(name, converter, builtin_converter._mapping[converter])

        elif converter is bool:
            converter = self._convert_builtin_type(name, bool, _boolconverter)

        elif converter in (str, int):
            original: type[str | int] = converter  # type: ignore
            converter = self._convert_builtin_type(name, original, lambda _, arg: original(arg))

        elif self._is_optional_argument(converter):
            return self.resolve_optional_callback(name, converter, ctx)

        elif isinstance(converter, types.UnionType) or getattr(converter, "__origin__", None) is Union:
            return self.resolve_union_callback(name, converter)  # type: ignore

        elif hasattr(converter, "__metadata__"):  # Annotated
            annotated = converter.__metadata__  # type: ignore
            return self._resolve_converter(name, annotated[0], ctx)

        return converter  # type: ignore

    def _convert_builtin_type(
        self,
        arg_name: str,
        original: type,
        converter: Union[Callable[[Context, str], Any], Callable[[Context, str], Awaitable[Any]]],
    ) -> Callable[[Context, str], Awaitable[Any]]:
        async def resolve(ctx, arg: str) -> Any:
            try:
                t = converter(ctx, arg)

                if inspect.iscoroutine(t):
                    t = await t

                return t
            except Exception as e:
                raise ArgumentParsingFailed(
                    f"Failed to convert `{arg}` to expected type {original.__name__} for argument `{arg_name}`",
                    original=e,
                    argname=arg_name,
                    expected=original,
                ) from e

        return resolve

    async def _convert_types(self, context: Context, param: inspect.Parameter, parsed: str) -> Any:
        converter = param.annotation

        if converter is param.empty:
            if param.default in (param.empty, None):
                converter = str
            else:
                converter = type(param.default)

        true_converter = self._resolve_converter(param.name, converter, context)

        try:
            argument = true_converter(context, parsed)
            if inspect.iscoroutine(argument):
                argument = await argument
        except BadArgument as e:
            if e.name is None:
                e.name = param.name

            raise
        except Exception as e:
            raise ArgumentParsingFailed(
                f"Failed to parse `{parsed}` for argument {param.name}", original=e, argname=param.name, expected=None
            ) from e
        return argument

    async def parse_args(self, context: Context, instance: Optional[Cog], parsed: dict, index=0) -> Tuple[list, dict]:
        """|coro|

        Convert the words of an invocation into arguments for the callback.

        Parameters
        ------------
        context: :class:`Context`
            The invocation context handed to every converter.
        instance: Optional[:class:`~twitchio.ext.commands.Cog`]
            When truthy, the first two callback parameters are skipped instead of only the first.
        parsed: :class:`dict`
            Mapping of word position to word. It is consumed in place, except that a
            :class:`Group` works on a copy.
        index: :class:`int`
            The position just before the first word to read; it is incremented before each lookup.

        Returns
        --------
        Tuple[:class:`list`, :class:`dict`]
            Positional arguments and keyword arguments for the callback.

        Raises
        --------
        TwitchCommandError
            The callback has too few leading parameters to skip.
        MissingRequiredArgument
            A parameter without a default has no word left to consume.
        """
        if isinstance(self, Group):
            parsed = parsed.copy()
        iterator = iter(self.params.items())
        args = []
        kwargs = {}

        def remaining_words() -> list:
            # When an optional argument is rolled back its word is re-inserted, which places
            # it last in the dict's insertion order. Reading by sorted position restores the
            # order in which the words were written.
            return [parsed[position] for position in sorted(parsed)]

        try:
            next(iterator)
            if instance:
                next(iterator)
        except StopIteration:
            raise TwitchCommandError("self or ctx is a required argument which is missing.")
        for _, param in iterator:
            index += 1
            if param.kind == param.POSITIONAL_OR_KEYWORD:
                try:
                    argument = parsed.pop(index)
                except (KeyError, IndexError):
                    if self._is_optional_argument(param.annotation):  # parameter is optional and at the end.
                        args.append(param.default if param.default is not param.empty else None)
                        continue

                    if param.default is param.empty:
                        raise MissingRequiredArgument(argname=param.name)

                    args.append(param.default)
                else:
                    _parsed_arg = await self._convert_types(context, param, argument)

                    if _parsed_arg is EMPTY:
                        # The optional converter rejected this word: put it back and let the
                        # next parameter try it, while this parameter takes its default.
                        parsed[index] = argument
                        index -= 1
                        args.append(param.default if param.default is not param.empty else None)

                        continue
                    else:
                        args.append(_parsed_arg)

            elif param.kind == param.KEYWORD_ONLY:
                # A keyword-only parameter swallows every remaining word as one string.
                rest = " ".join(remaining_words()).lstrip(" ")
                if rest:
                    rest = await self._convert_types(context, param, rest)
                elif param.default is param.empty:
                    raise MissingRequiredArgument(argname=param.name)
                else:
                    rest = param.default
                kwargs[param.name] = rest
                parsed.clear()
                break
            elif param.kind == param.VAR_POSITIONAL:
                # ``*args`` converts each remaining word individually.
                args.extend([await self._convert_types(context, param, argument) for argument in remaining_words()])
                parsed.clear()
                break
        if parsed:
            pass  # TODO Raise Too Many Arguments.
        return args, kwargs

    async def invoke(self, context: Context, *, index=0) -> None:
        # TODO Docs
        if not context.view:
            return

        async def try_run(func, *, to_command=False):
            try:
                await func
            except Exception as _e:
                if not to_command:
                    context.bot.run_event("error", _e)
                else:
                    context.bot.run_event("command_error", context, _e)

        try:
            args, kwargs = await self.parse_args(context, self._instance, context.view.words, index=index)
        except (MissingRequiredArgument, BadArgument) as e:
            if self.event_error:
                args_ = [self._instance, context] if self._instance else [context]
                await try_run(self.event_error(*args_, e))

            context.bot.run_event("command_error", context, e)
            return

        context.args, context.kwargs = args, kwargs
        check_result = await self.handle_checks(context)

        if check_result is not True:
            context.bot.run_event("command_error", context, check_result)
            return
        limited = self._run_cooldowns(context)

        if limited:
            context.bot.run_event("command_error", context, limited[0])
            return
        instance = self._instance
        args = [instance, context] if instance else [context]
        await try_run(context.bot.global_before_invoke(context))

        if self._before_invoke:
            await try_run(self._before_invoke(*args), to_command=True)
        try:
            await self._callback(*args, *context.args, **context.kwargs)
        except Exception as e:
            if self.event_error:
                await try_run(self.event_error(*args, e))
            context.bot.run_event("command_error", context, e)
        else:
            context.bot.run_event("command_complete", context)
        # Invoke our after command hooks
        if self._after_invoke:
            await try_run(self._after_invoke(*args), to_command=True)
        await try_run(context.bot.global_after_invoke(context))

    def _run_cooldowns(self, context: Context) -> Optional[List[CommandOnCooldown]]:
        try:
            buckets = self._cooldowns[0].get_buckets(context)
        except IndexError:
            return None
        expired = []

        try:
            for bucket in buckets:
                bucket.update_bucket(context)
        except CommandOnCooldown as e:
            expired.append(e)
        return expired

    async def handle_checks(self, context: Context) -> Union[Literal[True], Exception]:
        # TODO Docs

        if not self.no_global_checks:
            checks = [predicate for predicate in itertools.chain(context.bot._checks, self._checks)]
        else:
            checks = self._checks
        try:
            for predicate in checks:
                result = predicate(context)

                if inspect.isawaitable(result):
                    result = await result  # type: ignore
                if not result:
                    raise CheckFailure(f"The check {predicate} for command {self.name} failed.")
            if self.cog and not await self.cog.cog_check(context):
                raise CheckFailure(f"The cog check for command <{self.name}> failed.")
            return True
        except Exception as e:
            return e

    async def __call__(self, context: Context, *, index=0) -> None:
        """|coro|

        Calling the command object awaits :meth:`invoke` with the same ``context`` and ``index``.
        """
        await self.invoke(context, index=index)


class Group(Command):
    """A command that can own sub-commands.

    When the group is called, the first remaining word of the invocation is looked up
    among the registered sub-command names and aliases. A match is dispatched to that
    sub-command; otherwise the group's own callback is invoked.

    Parameters
    ------------
    invoke_with_subcommand: :class:`bool`
        When true, the group's own callback is also invoked after a matching sub-command.
        All other arguments are forwarded to :class:`Command`.
    """

    def __init__(self, *args, invoke_with_subcommand=False, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Maps every sub-command name and alias to its command object.
        self._sub_commands = {}
        self._invoke_with_subcommand = invoke_with_subcommand

    async def __call__(self, context: Context, *, index=0) -> None:
        """|coro|

        Dispatch to a sub-command when the first remaining word names one, else invoke the group.
        """
        if not context.view:
            return
        if not context.view.words:
            return await self.invoke(context, index=index)

        # Only the first remaining (position, word) pair is needed.
        position, word = next(iter(context.view.words.items()))

        if word not in self._sub_commands:
            await self.invoke(context, index=index)
            return

        # The sub-command gets its own context and view copy with the matched word removed,
        # so the group's context still holds the full word list.
        sub_context = copy.copy(context)
        sub_context.view = sub_context.view.copy()
        sub_context.view.words.pop(position)
        await self._sub_commands[word](sub_context, index=position)

        if self._invoke_with_subcommand:
            await self.invoke(context, index=index)

    def _register_subcommand(self, cmd: Command) -> Command:
        """Store ``cmd`` under its name and each of its aliases, then return it."""
        self._sub_commands[cmd.name] = cmd
        for alias in cmd.aliases or ():
            self._sub_commands[alias] = cmd
        return cmd

    def command(
        self,
        *,
        name: Optional[str] = None,
        aliases: Optional[Union[list, tuple]] = None,
        cls=Command,
        no_global_checks=False,
    ) -> Callable[[Callable], Command]:
        """Decorator that registers a coroutine function as a sub-command of this group.

        Parameters
        ------------
        name: Optional[:class:`str`]
            The sub-command name; defaults to the function's ``__name__``.
        aliases: Optional[Union[:class:`list`, :class:`tuple`]]
            Additional names under which the sub-command is registered.
        cls:
            The class used to build the sub-command.
        no_global_checks: :class:`bool`
            Forwarded to the sub-command.

        Raises
        --------
        TypeError
            ``cls`` is truthy but not a class.
        """
        if cls and not inspect.isclass(cls):
            raise TypeError(f"cls must be of type <class> not <{type(cls)}>")

        def decorator(func: Callable):
            fname = name or func.__name__
            cmd = cls(name=fname, func=func, aliases=aliases, no_global_checks=no_global_checks, parent=self)
            return self._register_subcommand(cmd)

        return decorator

    def group(
        self,
        *,
        name: Optional[str] = None,
        aliases: Optional[Union[list, tuple]] = None,
        cls: Optional[Type[Group]] = None,
        no_global_checks=False,
        invoke_with_subcommand=False,
    ) -> Callable[[Callable], Group]:
        """Decorator that registers a coroutine function as a nested group of this group.

        Takes the same arguments as :meth:`command`, plus ``invoke_with_subcommand`` which is
        forwarded to the nested group. ``cls`` defaults to :class:`Group`.

        Raises
        --------
        TypeError
            ``cls`` is not a class.
        """
        cls = cls or Group
        if cls and not inspect.isclass(cls):
            raise TypeError(f"cls must be of type <class> not <{type(cls)}>")

        def decorator(func: Callable):
            fname = name or func.__name__
            cmd = cls(
                name=fname,
                func=func,
                aliases=aliases,
                no_global_checks=no_global_checks,
                parent=self,
                invoke_with_subcommand=invoke_with_subcommand,
            )
            return self._register_subcommand(cmd)

        return decorator


class Context(Messageable):
    """
    A class that represents the context in which a command is being invoked under.

    This class contains the meta data to help you understand more about the invocation context.
    This class is not created manually and is instead passed around to commands as the first parameter.

    Attributes
    -----------
    message: :class:`~twitchio.Message`
        The message that triggered the command being executed.
    channel: :class:`~twitchio.Channel`
        The channel the command was invoked in.
    author: Union[:class:`~twitchio.PartialChatter`, :class:`~twitchio.Chatter`]
        The Chatter object of the user in chat that invoked the command.
    prefix: Optional[:class:`str`]
        The prefix that was used to invoke the command.
    command: Optional[:class:`~twitchio.ext.commands.Command`]
        The command that was invoked
    cog: Optional[:class:`~twitchio.ext.commands.Cog`]
        The cog that contains the command that was invoked.
    args: Optional[List[:class:`Any`]]
        List of arguments that were passed to the command.
    kwargs: Optional[Dict[:class:`str`, :class:`Any`]]
        Dict of kwargs that were passed to the command.
    view: Optional[:class:`~twitchio.ext.commands.StringParser`]
        StringParser object that breaks down the command string received.
    bot: :class:`~twitchio.ext.commands.Bot`
        The bot that contains the command that was invoked.
    """

    __messageable_channel__ = True

    def __init__(self, message: Message, bot: Bot, **attrs) -> None:
        """Create a context for ``message``.

        The channel and author are taken from the message. The optional keyword
        attributes read are ``prefix``, ``command``, ``args``, ``kwargs``, ``view``
        and ``valid`` (stored as ``is_valid``); each is ``None`` when not supplied.
        """
        self.message: Message = message
        self.channel: Channel = message.channel
        self.author: Union[Chatter, PartialChatter] = message.author

        self.prefix: Optional[str] = attrs.get("prefix")

        self.command: Optional[Command] = attrs.get("command")
        # Always define ``cog`` so that reading it on a context without a command
        # gives ``None`` instead of raising AttributeError.
        self.cog: Optional[Cog] = self.command.cog if self.command else None
        self.args: Optional[list] = attrs.get("args")
        self.kwargs: Optional[dict] = attrs.get("kwargs")

        self.view: Optional[StringParser] = attrs.get("view")
        self.is_valid: bool = attrs.get("valid")

        self.bot: Bot = bot
        self._ws = self.author._ws

    def _fetch_channel(self) -> Messageable:
        """Return the channel when it is truthy, otherwise the author."""
        return self.channel or self.author  # Abstract method

    def _fetch_websocket(self):
        """Return the websocket object stored on this context (taken from the author in ``__init__``)."""
        return self._ws  # Abstract method

    def _fetch_message(self):
        """Return the message this context was created from."""
        return self.message  # Abstract method

    def _bot_is_mod(self) -> bool:
        """Return the ``is_mod`` value of the bot's own entry in the channel's user cache.

        ``False`` is returned when there is no channel, the channel has no cache entry,
        the bot's nick is not found in the cache, or the cached user has no ``is_mod``
        attribute.
        """
        if not self.channel:
            return False

        try:
            cache = self._ws._cache[self.channel._name]
        except KeyError:
            # Nothing is cached for this channel, so moderator status cannot be confirmed.
            return False

        for user in cache:
            if user.name == self._ws.nick:
                return getattr(user, "is_mod", False)

        # The bot is not among the cached users.
        return False

    @property
    def chatters(self) -> Optional[Set[Chatter]]:
        """The channels current chatters.

        Read from the websocket's cache entry for this channel; ``None`` when the lookup
        raises :exc:`KeyError` or :exc:`AttributeError`.
        """
        try:
            return self._ws._cache[self.channel._name]
        except (KeyError, AttributeError):
            return None

    @property
    def users(self) -> Optional[Set[Chatter]]:  # Alias to chatters
        """Alias to :attr:`chatters`; returns exactly what that property returns."""
        return self.chatters

    def get_user(self, name: str) -> Optional[Union[PartialChatter, Chatter]]:
        """Retrieve a user from the channels user cache.

        The lookup compares the lower-cased ``name`` against each cached user's ``name``.

        Parameters
        -----------
        name: str
            The user's name to try and retrieve.

        Returns
        --------
        Union[:class:`twitchio.Chatter`, :class:`twitchio.PartialChatter`]
            Could be a :class:`twitchio.PartialChatter` depending on how the user joined the channel.
            Returns None if no user was found, if there is no channel, or if the channel
            has no entry in the cache.
        """
        name = name.lower()

        if not self.channel:
            return None

        try:
            cache = self._ws._cache[self.channel._name]
        except KeyError:
            # No cache entry for this channel means no user can be found.
            return None

        return next((user for user in cache if user.name == name), None)

    async def reply(self, content: str):
        """|coro|


        Send ``content`` as a reply to the message this context was created from.

        When the destination is flagged as a messageable channel, the reply is sent through
        the websocket's ``reply`` with the original message id. Otherwise the content is
        sent as a ``/w`` whisper command addressed to the destination's name.

        Parameters
        ------------
        content: str
            The content you wish to send as a message. The content must be a string.

        Raises
        --------
        InvalidContent
            Invalid content.
        """
        target = self._fetch_channel()
        socket = self._fetch_websocket()
        source = self._fetch_message()

        self.check_content(content)
        self.check_bucket(channel=target.name)

        # Prefer the name of the target's channel; fall back to the target's own name.
        try:
            destination = target.channel.name
        except AttributeError:
            destination = target.name

        if not target.__messageable_channel__:
            await socket.send(f"PRIVMSG #jtv :/w {destination} {content}\r\n")
            return

        await socket.reply(source.id, f"PRIVMSG #{destination} :{content}\r\n")


# Type variables tying the ``cls`` argument of the decorators below to their return types.
C = TypeVar("C", bound="Command")
G = TypeVar("G", bound="Group")


def command(
    *, name: str = None, aliases: Union[list, tuple] = None, cls: type[C] = Command, no_global_checks=False
) -> Callable[[Callable], C]:
    """Decorator factory that turns a coroutine function into a command object.

    Parameters
    ------------
    name: :class:`str`
        The command name; defaults to the decorated function's ``__name__``.
    aliases: Union[:class:`list`, :class:`tuple`]
        Passed to the command as its aliases.
    cls:
        The class instantiated to build the command. Defaults to :class:`Command`.
    no_global_checks: :class:`bool`
        Passed to the command unchanged.

    Raises
    --------
    TypeError
        ``cls`` is truthy but not a class.
    """
    if cls and not inspect.isclass(cls):
        raise TypeError(f"cls must be of type <class> not <{type(cls)}>")

    def decorator(func: Callable) -> C:
        settings = {
            "name": name or func.__name__,
            "func": func,
            "aliases": aliases,
            "no_global_checks": no_global_checks,
        }
        return cls(**settings)

    return decorator


def group(
    *,
    name: str = None,
    aliases: Union[list, tuple] = None,
    cls: Type[G] = Group,
    no_global_checks=False,
    invoke_with_subcommand=False,
) -> Callable[[Callable], G]:
    """Decorator factory that turns a coroutine function into a command group.

    Parameters
    ------------
    name: :class:`str`
        The group name; defaults to the decorated function's ``__name__``.
    aliases: Union[:class:`list`, :class:`tuple`]
        Passed to the group as its aliases.
    cls:
        The class instantiated to build the group (a class, not an instance).
        Defaults to :class:`Group`.
    no_global_checks: :class:`bool`
        Passed to the group unchanged.
    invoke_with_subcommand: :class:`bool`
        Passed to the group unchanged.

    Raises
    --------
    TypeError
        ``cls`` is truthy but not a class.
    """
    if cls and not inspect.isclass(cls):
        raise TypeError(f"cls must be of type <class> not <{type(cls)}>")

    def decorator(func: Callable) -> G:
        fname = name or func.__name__
        return cls(
            name=fname,
            func=func,
            aliases=aliases,
            no_global_checks=no_global_checks,
            invoke_with_subcommand=invoke_with_subcommand,
        )

    return decorator


# Type variable expressing that the decorator returns the object it was given.
FN = TypeVar("FN")


def cooldown(rate, per, bucket=Bucket.default):
    """Decorator factory that attaches a :class:`Cooldown` built from ``rate``, ``per`` and ``bucket``.

    Applied to a :class:`Command`, the cooldown is appended to the command's cooldown list.
    Applied to anything else, the object's ``__cooldowns__`` attribute is set to a new
    one-element list, replacing any previous value. The decorated object is returned as is.
    """

    def decorator(func: FN) -> FN:
        if not isinstance(func, Command):
            func.__cooldowns__ = [Cooldown(rate, per, bucket)]
            return func

        func._cooldowns.append(Cooldown(rate, per, bucket))
        return func

    return decorator