mirror of
https://github.com/ARM-software/devlib.git
synced 2025-01-31 02:00:45 +00:00
fb4e155696
Provide an implementation of re-entrant asyncio.run() that is less brittle than what greenback provides (e.g. no use of ctypes to poke extension types). The general idea of the implementation consists in treating the executed coroutine as a generator, then turning that generator into a generator implemented using greenlet. This allows a nested function to make the top-level parent yield values on its behalf, as if every call was annotated with "yield from".
609 lines
13 KiB
Python
609 lines
13 KiB
Python
#
|
|
# Copyright 2024 ARM Limited
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
#
|
|
|
|
import sys
|
|
import asyncio
|
|
from functools import partial
|
|
import contextvars
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
from contextlib import contextmanager
|
|
|
|
from pytest import skip, raises
|
|
|
|
from devlib.utils.asyn import run, asynccontextmanager
|
|
|
|
|
|
class AsynTestExcep(Exception):
    """Marker exception raised by the test coroutines and expected by the assertions."""
|
|
|
|
|
|
class Awaitable:
    """
    Minimal awaitable that suspends exactly once.

    Awaiting an instance yields the instance itself to whatever drives the
    coroutine. The ``await`` expression then evaluates to the value passed
    back in with ``send()``, or raises whatever is injected with ``throw()``.
    """

    def __await__(self):
        resumed_with = yield self
        return resumed_with
|
|
|
|
|
|
@contextmanager
def raises_and_bubble(cls):
    """
    Assert that the ``with`` body raises an instance of ``cls``, and let that
    exception keep propagating.

    Unlike :func:`pytest.raises`, the expected exception is not swallowed, so
    enclosing code (e.g. an outer ``raises()``) can check it bubbled up.

    :param cls: Expected exception type.
    :raises AssertionError: If the body raises nothing, or raises an
        :class:`Exception` that is not an instance of ``cls``. The unexpected
        exception is chained as ``__cause__``.
    """
    try:
        yield
    except cls:
        # Expected: re-raise the very same exception object.
        raise
    except Exception as e:
        # Only ordinary exceptions count as "wrong exception". Control-flow
        # BaseExceptions (GeneratorExit, KeyboardInterrupt, SystemExit, ...)
        # must propagate untouched, otherwise coroutine/generator teardown
        # passing through this context manager would be turned into a bogus
        # AssertionError.
        raise AssertionError(f'Did not raise instance of {cls}') from e
    else:
        raise AssertionError('Did not raise any exception')
|
|
|
|
|
|
@contextmanager
def coro_stop_iteration(x):
    """
    Assert that the ``with`` body finishes a coroutine whose return value is
    ``x``.

    Completion is detected through the :class:`StopIteration` raised by
    ``send()``/``throw()`` on the last step. Its ``value`` attribute carries
    the coroutine's return value and is compared with ``x``. Any other
    exception propagates unchanged. If the body raises nothing, an
    :class:`AssertionError` is raised.
    """
    finished = False
    try:
        yield
    except StopIteration as stop:
        finished = True
        assert stop.value == x

    if not finished:
        raise AssertionError('Coroutine did not finish')
|
|
|
|
|
|
def _do_test_run(top_run):
    """
    Run the re-entrant ``run()`` test suite using ``top_run`` as the
    outermost runner.

    Each nested ``test_*`` function is defined and then immediately invoked.
    The coroutine ones are driven by ``top_run``. Together they exercise:

    * basic execution of devlib's ``run()``;
    * contextvars propagation;
    * exception propagation;
    * nesting ``run()`` from async and from blocking code;
    * the raw coroutine ``send()``/``throw()`` protocol;
    * devlib's ``asynccontextmanager``;
    * async generators.

    :param top_run: Callable taking a coroutine (or awaitable) and running it
        to completion, e.g. :func:`asyncio.run` or devlib's ``run``.
    """

    async def test_run_basic():

        async def f():
            return 42

        assert run(f()) == 42

    top_run(test_run_basic())

    async def test_run_basic_contextvars_get():
        var = contextvars.ContextVar('var')
        var.set(42)

        async def f():
            return var.get()

        assert var.get() == 42
        assert run(f()) == 42

    top_run(test_run_basic_contextvars_get())

    async def test_run_basic_contextvars_set():
        var = contextvars.ContextVar('var')

        async def f():
            var.set(43)

        var.set(42)
        assert var.get() == 42
        run(f())
        # A value set by the coroutine executed with run() is expected to be
        # visible to the caller afterwards.
        assert var.get() == 43

    top_run(test_run_basic_contextvars_set())

    async def test_run_basic_raise():

        async def f():
            raise AsynTestExcep

        with raises(AsynTestExcep):
            run(f())

    top_run(test_run_basic_raise())

    async def test_run_basic_await():
        async def nested():
            return 42

        async def f():
            return await nested()

        assert run(f()) == 42

    top_run(test_run_basic_await())

    async def test_run_basic_await_raise():
        async def nested():
            raise AsynTestExcep

        async def f():
            with raises_and_bubble(AsynTestExcep):
                return await nested()

        with raises(AsynTestExcep):
            run(f())

    top_run(test_run_basic_await_raise())

    async def test_run_nested1():
        async def nested():
            return 42

        async def f():
            return run(nested())

        assert run(f()) == 42

    top_run(test_run_nested1())

    async def test_run_nested1_raise():
        async def nested():
            raise AsynTestExcep

        async def f():
            with raises_and_bubble(AsynTestExcep):
                return run(nested())

        with raises(AsynTestExcep):
            run(f())

    top_run(test_run_nested1_raise())

    async def test_run_nested2():
        async def nested2():
            return 42

        async def nested1():
            return run(nested2())

        async def f():
            return run(nested1())

        assert run(f()) == 42

    top_run(test_run_nested2())

    async def test_run_nested2_raise():
        async def nested2():
            raise AsynTestExcep

        async def nested1():
            with raises_and_bubble(AsynTestExcep):
                return run(nested2())

        async def f():
            with raises_and_bubble(AsynTestExcep):
                return run(nested1())

        with raises(AsynTestExcep):
            run(f())

    top_run(test_run_nested2_raise())

    async def test_run_nested2_block():
        async def nested2():
            return 42

        # Plain (blocking) function calling run() from within a coroutine.
        def nested1():
            return run(nested2())

        async def f():
            return nested1()

        assert run(f()) == 42

    top_run(test_run_nested2_block())

    async def test_run_nested2_block_raise():
        async def nested2():
            raise AsynTestExcep

        def nested1():
            with raises_and_bubble(AsynTestExcep):
                return run(nested2())

        async def f():
            with raises_and_bubble(AsynTestExcep):
                return nested1()

        with raises(AsynTestExcep):
            run(f())

    top_run(test_run_nested2_block_raise())

    async def test_coro_send():
        async def f():
            return await Awaitable()

        # Drive the coroutine by hand: the first send() suspends it at the
        # Awaitable, the second resumes it with 42 and lets it return.
        coro = f()
        coro.send(None)

        with coro_stop_iteration(42):
            coro.send(42)

    top_run(test_coro_send())

    async def test_coro_nested_send():
        async def nested():
            return await Awaitable()

        async def f():
            return await nested()

        coro = f()
        coro.send(None)

        with coro_stop_iteration(42):
            coro.send(42)

    top_run(test_coro_nested_send())

    async def test_coro_nested_send2():
        future = asyncio.Future()
        future.set_result(42)

        async def nested():
            return await future

        async def f():
            return run(nested())

        assert run(f()) == 42

    top_run(test_coro_nested_send2())

    async def test_coro_nested_send3():
        future = asyncio.Future()
        future.set_result(42)

        async def nested2():
            return await future

        async def nested():
            return run(nested2())

        async def f():
            return run(nested())

        assert run(f()) == 42

    top_run(test_coro_nested_send3())

    async def test_coro_throw():
        async def f():
            try:
                await Awaitable()
            except AsynTestExcep:
                return 42

        coro = f()
        coro.send(None)

        with coro_stop_iteration(42):
            coro.throw(AsynTestExcep)

    top_run(test_coro_throw())

    async def test_coro_throw2():
        async def f():
            await Awaitable()

        coro = f()
        coro.send(None)

        with raises(AsynTestExcep):
            coro.throw(AsynTestExcep)

    top_run(test_coro_throw2())

    async def test_coro_nested_throw():
        async def nested():
            try:
                await Awaitable()
            except AsynTestExcep:
                return 42

        async def f():
            return await nested()

        coro = f()
        coro.send(None)

        with coro_stop_iteration(42):
            coro.throw(AsynTestExcep)

    top_run(test_coro_nested_throw())

    async def test_coro_nested_throw2():
        async def nested():
            await Awaitable()

        async def f():
            with raises_and_bubble(AsynTestExcep):
                await nested()

        coro = f()
        coro.send(None)

        with raises(AsynTestExcep):
            coro.throw(AsynTestExcep)

    top_run(test_coro_nested_throw2())

    async def test_coro_nested_throw3():
        future = asyncio.Future()
        future.set_exception(AsynTestExcep())

        async def nested():
            await future

        async def f():
            with raises_and_bubble(AsynTestExcep):
                run(nested())

        with raises(AsynTestExcep):
            run(f())

    top_run(test_coro_nested_throw3())

    async def test_coro_nested_throw4():
        future = asyncio.Future()
        future.set_exception(AsynTestExcep())

        async def nested2():
            await future

        async def nested():
            return run(nested2())

        async def f():
            with raises_and_bubble(AsynTestExcep):
                run(nested())

        with raises(AsynTestExcep):
            run(f())

    top_run(test_coro_nested_throw4())

    async def test_async_cm():
        state = None

        async def f():
            return 43

        @asynccontextmanager
        async def cm():
            nonlocal state
            state = 'started'
            await f()
            try:
                yield 42
            finally:
                await f()
                state = 'finished'

        async with cm() as x:
            assert state == 'started'
            assert x == 42

        assert state == 'finished'

    top_run(test_async_cm())

    async def test_async_cm2():
        state = None

        async def f():
            return 43

        @asynccontextmanager
        async def cm():
            nonlocal state
            state = 'started'
            await f()
            try:
                await f()
                yield 42
                await f()
            except AsynTestExcep:
                await f()
                # Swallow the exception
                pass
            finally:
                await f()
                state = 'finished'

        # The exception raised in the body is swallowed by cm() above, so it
        # must not escape the "async with".
        async with cm() as x:
            assert state == 'started'
            raise AsynTestExcep()

        assert state == 'finished'

    top_run(test_async_cm2())

    async def test_async_cm3():
        state = None

        async def f():
            return 43

        @asynccontextmanager
        async def cm():
            nonlocal state
            state = 'started'
            await f()
            try:
                yield 42
            finally:
                await f()
                state = 'finished'

        # Synchronous "with" on devlib's asynccontextmanager, from async code.
        with cm() as x:
            assert state == 'started'
            assert x == 42

        assert state == 'finished'

    top_run(test_async_cm3())

    def test_async_cm4():
        state = None

        async def f():
            return 43

        @asynccontextmanager
        async def cm():
            nonlocal state
            state = 'started'
            await f()
            try:
                yield 42
            finally:
                await f()
                state = 'finished'

        # Synchronous "with" on devlib's asynccontextmanager, from plain
        # blocking code (no top_run involved).
        with cm() as x:
            assert state == 'started'
            assert x == 42

        assert state == 'finished'

    test_async_cm4()

    def test_async_cm5():
        @asynccontextmanager
        async def cm_f():
            yield 42

        cm = cm_f()
        assert top_run(cm.__aenter__()) == 42
        assert not top_run(cm.__aexit__(None, None, None))

    test_async_cm5()

    def test_async_gen1():
        async def agen_f():
            for i in range(2):
                yield i

        agen = agen_f()
        # agen.__anext__() is exactly what the single-argument builtin
        # anext(agen) returns. Calling it directly keeps this test working on
        # Python < 3.10, where anext() does not exist.
        assert top_run(agen.__anext__()) == 0
        assert top_run(agen.__anext__()) == 1

    test_async_gen1()
|
|
|
|
|
|
def _test_in_thread(setup, test):
|
|
def f():
|
|
with setup() as run:
|
|
return test()
|
|
|
|
with ThreadPoolExecutor(max_workers=1) as pool:
|
|
return pool.submit(f).result()
|
|
|
|
|
|
def _test_run_with_setup(setup):
    """
    Run :func:`_do_test_run` once per top-level runner, each time in a fresh
    worker thread prepared by ``setup``.

    The top-level runners are:

    * devlib's ``run``;
    * :func:`asyncio.run`;
    * two helpers that drive the coroutine while an event loop already exists;
    * each of the above wrapped in an extra :func:`asyncio.to_thread` layer.

    :param setup: Zero-argument callable returning a context manager, passed
        through to :func:`_test_in_thread`.
    """

    def run_with_existing_loop(coro):
        # Simulate the case where devlib is run in a context where the main
        # app has set an event loop at some point.
        preexisting = asyncio.new_event_loop()
        asyncio.set_event_loop(preexisting)
        try:
            return asyncio.run(coro)
        finally:
            preexisting.close()

    def run_with_existing_loop2(coro):
        # This is similar to how things are executed on IPython/jupyterlab.
        manual_loop = asyncio.new_event_loop()
        try:
            return manual_loop.run_until_complete(coro)
        finally:
            manual_loop.close()

    def run_with_to_thread(outer_run, coro):
        # Add a layer of asyncio.to_thread(), to simulate users combining the
        # blocking API with asyncio.to_thread(). That covers code written
        # before devlib gained async capabilities, or code wishing to preserve
        # compat with older devlib versions.
        async def wrapper():
            return await asyncio.to_thread(outer_run, coro)

        return outer_run(wrapper())

    direct_runners = [
        run,
        asyncio.run,
        run_with_existing_loop,
        run_with_existing_loop2,
    ]
    threaded_runners = [
        partial(run_with_to_thread, runner)
        for runner in direct_runners
    ]

    for top_run in direct_runners + threaded_runners:
        _test_in_thread(setup, partial(_do_test_run, top_run))
|
|
|
|
|
|
def test_run_stdlib():
    """
    Run the re-entrant ``run()`` test suite on the stdlib asyncio event loop.

    For each test run, setup() installs a brand-new stdlib event loop as the
    current loop of the thread that enters it. The loop is closed once that
    run is over.
    """

    @contextmanager
    def setup():
        fresh_loop = asyncio.new_event_loop()
        asyncio.set_event_loop(fresh_loop)
        try:
            yield asyncio.run
        finally:
            fresh_loop.close()

    _test_run_with_setup(setup)
|
|
|
|
|
|
def test_run_uvloop():
    """
    Run the re-entrant ``run()`` test suite with uvloop.

    The test is skipped when uvloop is not installed.
    """
    try:
        import uvloop
    except ImportError:
        # skip() raises, so nothing below runs without uvloop.
        skip('uvloop not installed')

    @contextmanager
    def setup():
        if sys.version_info >= (3, 11):
            # NOTE(review): _test_in_thread() does not use the value yielded
            # by setup(), so runner.run itself is never called by the suite.
            # Confirm that uvloop is actually exercised on this code path.
            with asyncio.Runner(loop_factory=uvloop.new_event_loop) as runner:
                yield runner.run
        else:
            # uvloop.install() replaces the process-wide event loop policy.
            # Restore the previous policy afterwards so uvloop does not leak
            # into tests run later in the same process (e.g. test_run_stdlib).
            previous_policy = asyncio.get_event_loop_policy()
            uvloop.install()
            try:
                yield asyncio.run
            finally:
                asyncio.set_event_loop_policy(previous_policy)

    _test_run_with_setup(setup)
|