Source code for jishaku.repl.compilation

# -*- coding: utf-8 -*-

"""
jishaku.repl.compilation
~~~~~~~~~~~~~~~~~~~~~~~~

Constants, functions and classes related to classifying, compiling and executing Python code.

:copyright: (c) 2021 Devon (Gorialis) R
:license: MIT, see LICENSE for more details.

"""

import ast
import asyncio
import inspect
import linecache
import typing

import import_expression  # type: ignore

from jishaku.functools import AsyncSender
from jishaku.repl.scope import Scope
from jishaku.repl.walkers import KeywordTransformer

# Source template for the async function that hosts REPL code.
#
# This is an f-string: the importer alias taken from
# ``import_expression.constants.IMPORTER`` is interpolated once, when this module
# is imported, while the doubled braces leave a literal ``{0}`` behind that
# ``wrap_code`` later fills with the argument list via ``CORO_CODE.format(args)``.
#
# ``wrap_code`` appends the user's statements to the body of the final ``try``
# block (after the placeholder ``pass``), so the ``finally`` clause copies the
# function's locals into ``_async_executor.scope.globals`` when that block is
# exited. The template therefore requires ``_async_executor`` to be one of the
# formatted arguments, and the ``try`` statement must remain the last statement
# of the function body.
CORO_CODE = f"""
async def _repl_coroutine({{0}}):
    import asyncio
    from importlib import import_module as {import_expression.constants.IMPORTER}

    import aiohttp
    import discord
    from discord.ext import commands

    try:
        import jishaku
    except ImportError:
        jishaku = None  # keep working even if in panic recovery mode

    try:
        pass
    finally:
        _async_executor.scope.globals.update(locals())
"""


def wrap_code(code: str, args: str = '', auto_return: bool = True) -> ast.Module:
    """
    Builds the module AST for an async function (or async generator) wrapping ``code``.

    The user code is parsed with inline import expression support and spliced into
    the ``try`` block of the ``CORO_CODE`` template, whose parameter list is ``args``.

    When ``auto_return`` is true and the final statement is a bare expression that
    is not already a ``yield``, that expression is rewritten to be yielded so its
    value is surfaced to whoever iterates the resulting function.
    """

    # Parse the user's code first so that its syntax errors take precedence.
    parsed_input: ast.Module = import_expression.parse(code, mode='exec')  # type: ignore
    wrapper: ast.Module = import_expression.parse(CORO_CODE.format(args), mode='exec')  # type: ignore

    # Give every template node a far-away line number so that only the user's
    # own nodes carry line numbers that correspond to their source.
    hidden_line = -100_000
    for template_node in ast.walk(wrapper):
        template_node.lineno = hidden_line
        template_node.end_lineno = hidden_line

    coroutine_def = wrapper.body[-1]  # async def ...:
    assert isinstance(coroutine_def, ast.AsyncFunctionDef)

    guarded = coroutine_def.body[-1]  # try:
    assert isinstance(guarded, ast.Try)

    # Splice the user's statements in after the template's placeholder `pass`.
    guarded.body.extend(parsed_input.body)

    ast.fix_missing_locations(wrapper)

    KeywordTransformer().generic_visit(guarded)

    # Nothing more to do when the caller opted out of the automatic yield.
    if not auto_return:
        return wrapper

    tail = guarded.body[-1]

    # Only a trailing bare expression that is not already a yield gets wrapped.
    if isinstance(tail, ast.Expr) and not isinstance(tail.value, ast.Yield):
        # Both new nodes inherit the position of the statement they replace.
        yielded = ast.copy_location(ast.Yield(tail.value), tail)
        guarded.body[-1] = ast.copy_location(ast.Expr(yielded), tail)

    return wrapper


class AsyncCodeExecutor:  # pylint: disable=too-few-public-methods
    """
    Executes/evaluates Python code inside of an async function or generator.

    Example
    -------

    .. code:: python3

        total = 0

        # prints 1, 2 and 3
        async for x in AsyncCodeExecutor('yield 1; yield 2; yield 3'):
            total += x
            print(x)

        # prints 6
        print(total)
    """

    __slots__ = ('args', 'arg_names', 'code', 'loop', 'scope', 'source', '_function')

    def __init__(
        self,
        code: str,
        scope: typing.Optional[Scope] = None,
        arg_dict: typing.Optional[typing.Dict[str, typing.Any]] = None,
        convertables: typing.Optional[typing.Dict[str, str]] = None,
        loop: typing.Optional[asyncio.BaseEventLoop] = None,
        auto_return: bool = True,
    ):
        """
        Wraps ``code`` into a module AST ready to be compiled on first use.

        Parameters
        ----------
        code
            The source text to execute.
        scope
            The scope whose globals/locals the code is executed in.
            A fresh ``Scope`` is created when omitted.
        arg_dict
            Extra ``name -> value`` pairs passed to the generated function as arguments.
        convertables
            ``old -> new`` text replacements applied to ``code`` if, and only if,
            the first attempt to wrap it fails with a syntax error.
        loop
            The event loop to store on the executor; defaults to ``asyncio.get_event_loop()``.
        auto_return
            Forwarded to ``wrap_code``; controls whether a trailing bare expression is yielded.

        Raises
        ------
        SyntaxError
            If the code cannot be parsed, including after ``convertables`` were applied
            (in which case the second error is chained from the first).
        """

        # The executor itself is always the first argument, so the generated
        # function can reach `_async_executor.scope`.
        self.args = [self]
        self.arg_names = ['_async_executor']

        if arg_dict:
            for key, value in arg_dict.items():
                self.arg_names.append(key)
                self.args.append(value)

        # Keep the text as given; it is what `create_linecache` later publishes.
        self.source = code

        signature = ', '.join(self.arg_names)

        try:
            self.code = wrap_code(code, args=signature, auto_return=auto_return)
        except (SyntaxError, IndentationError) as first_error:
            if not convertables:
                raise

            try:
                for key, value in convertables.items():
                    code = code.replace(key, value)
                # Forward auto_return here as well, so that the retry honours the
                # caller's choice instead of silently falling back to the default.
                self.code = wrap_code(code, args=signature, auto_return=auto_return)
            except (SyntaxError, IndentationError) as second_error:
                raise second_error from first_error

        self.scope = scope or Scope()
        self.loop = loop or asyncio.get_event_loop()
        self._function = None

    @property
    def function(self) -> typing.Callable[..., typing.Union[
        typing.Awaitable[typing.Any],
        typing.AsyncGenerator[typing.Any, typing.Any]
    ]]:
        """
        The function object produced from compiling the code.
        If the code has not been compiled yet, it will be done upon first access.
        """

        if self._function is not None:
            return self._function

        # Executing the module defines `_repl_coroutine` in the scope.
        exec(compile(self.code, '<repl>', 'exec'), self.scope.globals, self.scope.locals)  # pylint: disable=exec-used
        # Look in the locals first, then fall back to the globals.
        self._function = self.scope.locals.get('_repl_coroutine') or self.scope.globals['_repl_coroutine']

        return self._function

    def create_linecache(self) -> typing.List[str]:
        """
        Populates the line cache with the current source.
        Can be performed before printing a traceback to show correct source lines.

        Returns the list of newline-terminated source lines that was stored.
        """

        lines = [line + '\n' for line in self.source.splitlines()]

        # Registered under the same '<repl>' filename the code is compiled with.
        linecache.cache['<repl>'] = (
            len(self.source),  # Source length
            None,  # Time modified (None bypasses expunge)
            lines,  # Line list
            '<repl>'  # 'True' filename
        )

        return lines

    def __aiter__(self) -> typing.AsyncGenerator[typing.Any, typing.Any]:
        """
        Compiles the code if necessary and returns an async generator over its results.
        """

        return self.traverse(self.function)

    async def traverse(
        self,
        func: typing.Callable[..., typing.Union[
            typing.Awaitable[typing.Any],
            typing.AsyncGenerator[typing.Any, typing.Any]
        ]]
    ) -> typing.AsyncGenerator[typing.Any, typing.Any]:
        """
        Traverses an async function or generator, yielding each result.

        This function is private. The class should be used as an iterator instead of using this method.
        """

        try:
            if inspect.isasyncgenfunction(func):
                func_g: typing.Callable[..., typing.AsyncGenerator[typing.Any, typing.Any]] = func  # type: ignore
                # Relay each yielded result outwards, and hand whatever the
                # consumer sends back through the provided `send` callable.
                async for send, result in AsyncSender(func_g(*self.args)):  # type: ignore
                    send((yield result))
            else:
                func_a: typing.Callable[..., typing.Awaitable[typing.Any]] = func  # type: ignore
                # A plain coroutine function produces exactly one result.
                yield await func_a(*self.args)
        except Exception:  # pylint: disable=broad-except
            # Falsely populate the linecache to make the REPL line appear in tracebacks
            self.create_linecache()

            raise