1
0
mirror of https://github.com/ARM-software/devlib.git synced 2024-10-05 18:30:50 +01:00
devlib/tests/test_asyn.py
Douglas Raillard fb4e155696 utils/asyn: Replace nest_asyncio with greenlet
Provide an implementation of re-entrant asyncio.run() that is less
brittle than what greenback provides (e.g. no use of ctypes to poke
extension types).

The general idea of the implementation consists in treating the executed
coroutine as a generator, then turning that generator into a generator
implemented using greenlet. This allows a nested function to make the
top-level parent yield values on its behalf, as if every call was
annotated with "yield from".
2024-09-12 17:59:19 -05:00

609 lines
13 KiB
Python

#
# Copyright 2024 ARM Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import sys
import asyncio
from functools import partial
import contextvars
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager
from pytest import skip, raises
from devlib.utils.asyn import run, asynccontextmanager
class AsynTestExcep(Exception):
pass
class Awaitable:
def __await__(self):
return (yield self)
@contextmanager
def raises_and_bubble(cls):
try:
yield
except BaseException as e:
if isinstance(e, cls):
raise
else:
raise AssertionError(f'Did not raise instance of {cls}')
else:
raise AssertionError(f'Did not raise any exception')
@contextmanager
def coro_stop_iteration(x):
try:
yield
except StopIteration as e:
assert e.value == x
except BaseException:
raise
else:
raise AssertionError('Coroutine did not finish')
def _do_test_run(top_run):
async def test_run_basic():
async def f():
return 42
assert run(f()) == 42
top_run(test_run_basic())
async def test_run_basic_contextvars_get():
var = contextvars.ContextVar('var')
var.set(42)
async def f():
return var.get()
assert var.get() == 42
assert run(f()) == 42
top_run(test_run_basic_contextvars_get())
async def test_run_basic_contextvars_set():
var = contextvars.ContextVar('var')
async def f():
var.set(43)
var.set(42)
assert var.get() == 42
run(f())
assert var.get() == 43
top_run(test_run_basic_contextvars_set())
async def test_run_basic_raise():
async def f():
raise AsynTestExcep
with raises(AsynTestExcep):
run(f())
top_run(test_run_basic_raise())
async def test_run_basic_await():
async def nested():
return 42
async def f():
return await nested()
assert run(f()) == 42
top_run(test_run_basic_await())
async def test_run_basic_await_raise():
async def nested():
raise AsynTestExcep
async def f():
with raises_and_bubble(AsynTestExcep):
return await nested()
with raises(AsynTestExcep):
run(f())
top_run(test_run_basic_await_raise())
async def test_run_nested1():
async def nested():
return 42
async def f():
return run(nested())
assert run(f()) == 42
top_run(test_run_nested1())
async def test_run_nested1_raise():
async def nested():
raise AsynTestExcep
async def f():
with raises_and_bubble(AsynTestExcep):
return run(nested())
with raises(AsynTestExcep):
run(f())
top_run(test_run_nested1_raise())
async def test_run_nested2():
async def nested2():
return 42
async def nested1():
return run(nested2())
async def f():
return run(nested1())
assert run(f()) == 42
top_run(test_run_nested2())
async def test_run_nested2_raise():
async def nested2():
raise AsynTestExcep
async def nested1():
with raises_and_bubble(AsynTestExcep):
return run(nested2())
async def f():
with raises_and_bubble(AsynTestExcep):
return run(nested1())
with raises(AsynTestExcep):
run(f())
top_run(test_run_nested2_raise())
async def test_run_nested2_block():
async def nested2():
return 42
def nested1():
return run(nested2())
async def f():
return nested1()
assert run(f()) == 42
top_run(test_run_nested2_block())
async def test_run_nested2_block_raise():
async def nested2():
raise AsynTestExcep
def nested1():
with raises_and_bubble(AsynTestExcep):
return run(nested2())
async def f():
with raises_and_bubble(AsynTestExcep):
return nested1()
with raises(AsynTestExcep):
run(f())
top_run(test_run_nested2_block_raise())
async def test_coro_send():
async def f():
return await Awaitable()
coro = f()
coro.send(None)
with coro_stop_iteration(42):
coro.send(42)
top_run(test_coro_send())
async def test_coro_nested_send():
async def nested():
return await Awaitable()
async def f():
return await nested()
coro = f()
coro.send(None)
with coro_stop_iteration(42):
coro.send(42)
top_run(test_coro_nested_send())
async def test_coro_nested_send2():
future = asyncio.Future()
future.set_result(42)
async def nested():
return await future
async def f():
return run(nested())
assert run(f()) == 42
top_run(test_coro_nested_send2())
async def test_coro_nested_send3():
future = asyncio.Future()
future.set_result(42)
async def nested2():
return await future
async def nested():
return run(nested2())
async def f():
return run(nested())
assert run(f()) == 42
top_run(test_coro_nested_send3())
async def test_coro_throw():
async def f():
try:
await Awaitable()
except AsynTestExcep:
return 42
coro = f()
coro.send(None)
with coro_stop_iteration(42):
coro.throw(AsynTestExcep)
top_run(test_coro_throw())
async def test_coro_throw2():
async def f():
await Awaitable()
coro = f()
coro.send(None)
with raises(AsynTestExcep):
coro.throw(AsynTestExcep)
top_run(test_coro_throw2())
async def test_coro_nested_throw():
async def nested():
try:
await Awaitable()
except AsynTestExcep:
return 42
async def f():
return await nested()
coro = f()
coro.send(None)
with coro_stop_iteration(42):
coro.throw(AsynTestExcep)
top_run(test_coro_nested_throw())
async def test_coro_nested_throw2():
async def nested():
await Awaitable()
async def f():
with raises_and_bubble(AsynTestExcep):
await nested()
coro = f()
coro.send(None)
with raises(AsynTestExcep):
coro.throw(AsynTestExcep)
top_run(test_coro_nested_throw2())
async def test_coro_nested_throw3():
future = asyncio.Future()
future.set_exception(AsynTestExcep())
async def nested():
await future
async def f():
with raises_and_bubble(AsynTestExcep):
run(nested())
with raises(AsynTestExcep):
run(f())
top_run(test_coro_nested_throw3())
async def test_coro_nested_throw4():
future = asyncio.Future()
future.set_exception(AsynTestExcep())
async def nested2():
await future
async def nested():
return run(nested2())
async def f():
with raises_and_bubble(AsynTestExcep):
run(nested())
with raises(AsynTestExcep):
run(f())
top_run(test_coro_nested_throw4())
async def test_async_cm():
state = None
async def f():
return 43
@asynccontextmanager
async def cm():
nonlocal state
state = 'started'
await f()
try:
yield 42
finally:
await f()
state = 'finished'
async with cm() as x:
assert state == 'started'
assert x == 42
assert state == 'finished'
top_run(test_async_cm())
async def test_async_cm2():
state = None
async def f():
return 43
@asynccontextmanager
async def cm():
nonlocal state
state = 'started'
await f()
try:
await f()
yield 42
await f()
except AsynTestExcep:
await f()
# Swallow the exception
pass
finally:
await f()
state = 'finished'
async with cm() as x:
assert state == 'started'
raise AsynTestExcep()
assert state == 'finished'
top_run(test_async_cm2())
async def test_async_cm3():
state = None
async def f():
return 43
@asynccontextmanager
async def cm():
nonlocal state
state = 'started'
await f()
try:
yield 42
finally:
await f()
state = 'finished'
with cm() as x:
assert state == 'started'
assert x == 42
assert state == 'finished'
top_run(test_async_cm3())
def test_async_cm4():
state = None
async def f():
return 43
@asynccontextmanager
async def cm():
nonlocal state
state = 'started'
await f()
try:
yield 42
finally:
await f()
state = 'finished'
with cm() as x:
assert state == 'started'
assert x == 42
assert state == 'finished'
test_async_cm4()
def test_async_cm5():
@asynccontextmanager
async def cm_f():
yield 42
cm = cm_f()
assert top_run(cm.__aenter__()) == 42
assert not top_run(cm.__aexit__(None, None, None))
test_async_cm5()
def test_async_gen1():
async def agen_f():
for i in range(2):
yield i
agen = agen_f()
assert top_run(anext(agen)) == 0
assert top_run(anext(agen)) == 1
test_async_gen1()
def _test_in_thread(setup, test):
def f():
with setup() as run:
return test()
with ThreadPoolExecutor(max_workers=1) as pool:
return pool.submit(f).result()
def _test_run_with_setup(setup):
def run_with_existing_loop(coro):
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
# Simulate case where devlib is ran in a context where the main app has
# set an event loop at some point
try:
return asyncio.run(coro)
finally:
loop.close()
def run_with_existing_loop2(coro):
# This is similar to how things are executed on IPython/jupyterlab
loop = asyncio.new_event_loop()
try:
return loop.run_until_complete(coro)
finally:
loop.close()
def run_with_to_thread(top_run, coro):
# Add a layer of asyncio.to_thread(), to simulate a case where users
# would be using the blocking API along with asyncio.to_thread() (code
# written before devlib gained async capabilities or wishing to
# preserve compat with older devlib versions)
async def wrapper():
return await asyncio.to_thread(
top_run, coro
)
return top_run(wrapper())
runners = [
run,
asyncio.run,
run_with_existing_loop,
run_with_existing_loop2,
partial(run_with_to_thread, run),
partial(run_with_to_thread, asyncio.run),
partial(run_with_to_thread, run_with_existing_loop),
partial(run_with_to_thread, run_with_existing_loop2),
]
for top_run in runners:
_test_in_thread(
setup,
partial(_do_test_run, top_run),
)
def test_run_stdlib():
@contextmanager
def setup():
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
yield asyncio.run
finally:
loop.close()
_test_run_with_setup(setup)
def test_run_uvloop():
try:
import uvloop
except ImportError:
skip('uvloop not installed')
else:
@contextmanager
def setup():
if sys.version_info >= (3, 11):
with asyncio.Runner(loop_factory=uvloop.new_event_loop) as runner:
yield runner.run
else:
uvloop.install()
yield asyncio.run
_test_run_with_setup(setup)