twizzler_queue/
sender_queue.rs1use std::{
2 collections::BTreeMap,
3 future::Future,
4 pin::Pin,
5 sync::{
6 atomic::{AtomicU32, Ordering},
7 Arc, Mutex,
8 },
9 task::{Poll, Waker},
10};
11
12use async_io::Async;
13use futures::FutureExt;
14use twizzler_queue_raw::{QueueError, ReceiveFlags, SubmissionFlags};
15
16use crate::Queue;
17
18struct QueueSenderInner<S, C> {
19 queue: Queue<S, C>,
20}
21
/// Shared rendezvous slot for one outstanding request.
///
/// The completion side fills in `item` and takes `waker`; the waiting future
/// takes `item` or, if it is still empty, stores its waker here.
struct WaitPoint<C> {
    /// The `(request id, completion)` pair once it has arrived; `None` until then.
    item: Option<(u32, C)>,
    /// Waker of the task currently waiting on this slot, if any.
    waker: Option<Waker>,
}
26
27struct WaitPointFuture<'a, S: Copy + Send + Sync, C: Copy + Send + Sync> {
28 state: Arc<Mutex<WaitPoint<C>>>,
29 sender: &'a QueueSender<S, C>,
30}
31
32impl<'a, S: Copy + Send + Sync, C: Copy + Send + Sync> Future for WaitPointFuture<'a, S, C> {
33 type Output = Result<(u32, C), QueueError>;
34
35 fn poll(
36 self: std::pin::Pin<&mut Self>,
37 cx: &mut std::task::Context<'_>,
38 ) -> std::task::Poll<Self::Output> {
39 if let Some((id, item)) = self.sender.poll_completions() {
40 self.sender.handle_completion(id, item);
41 }
42 let mut state = self.state.lock().unwrap();
43 if let Some(item) = state.item.take() {
44 Poll::Ready(Ok(item))
45 } else {
46 state.waker = Some(cx.waker().clone());
47 Poll::Pending
48 }
49 }
50}
51
52pub struct QueueSender<S: Copy, C: Copy> {
58 counter: AtomicU32,
59 reuse: Mutex<Vec<u32>>,
60 inner: Async<Pin<Box<QueueSenderInner<S, C>>>>,
61 calls: Mutex<BTreeMap<u32, Arc<Mutex<WaitPoint<C>>>>>,
62}
63
64impl<S: Copy, C: Copy> twizzler_futures::TwizzlerWaitable for QueueSenderInner<S, C> {
65 fn wait_item_read(&self) -> twizzler_abi::syscall::ThreadSyncSleep {
66 self.queue.setup_read_com_sleep()
67 }
68
69 fn wait_item_write(&self) -> twizzler_abi::syscall::ThreadSyncSleep {
70 self.queue.setup_write_sub_sleep()
71 }
72}
73
74impl<S: Copy + Sync + Send, C: Copy + Send + Sync> QueueSender<S, C> {
75 pub fn new(queue: Queue<S, C>) -> Self {
77 Self {
78 counter: AtomicU32::new(0),
79 reuse: Mutex::new(vec![]),
80 inner: Async::new(QueueSenderInner { queue }).unwrap(),
81 calls: Mutex::new(BTreeMap::new()),
82 }
83 }
84
85 fn next_id(&self) -> u32 {
86 let mut reuse = self.reuse.lock().unwrap();
87 reuse
88 .pop()
89 .unwrap_or_else(|| self.counter.fetch_add(1, Ordering::SeqCst))
90 }
91
92 fn release_id(&self, id: u32) {
93 self.reuse.lock().unwrap().push(id)
94 }
95
96 fn poll_completions(&self) -> Option<(u32, C)> {
97 self.inner
98 .get_ref()
99 .queue
100 .get_completion(ReceiveFlags::NON_BLOCK)
101 .ok()
102 }
103
104 fn handle_completion(&self, id: u32, item: C) {
105 let mut calls = self.calls.lock().unwrap();
106 let call = calls
107 .remove(&id)
108 .expect("failed to find registered callback");
109 let mut call = call.lock().unwrap();
110 call.item = Some((id, item));
111 if let Some(waker) = call.waker.take() {
112 waker.wake();
113 }
114 }
115
116 pub fn submit_no_wait(&self, item: S, flags: SubmissionFlags) {
121 let _ = self
122 .inner
123 .get_ref()
124 .queue
125 .submit(self.next_id(), item, flags);
126 }
127
128 pub async fn submit_and_wait(&self, item: S) -> Result<C, std::io::Error> {
130 let id = self.next_id();
131 let state = Arc::new(Mutex::new(WaitPoint::<C> {
132 item: None,
133 waker: None,
134 }));
135 {
136 let mut calls = self.calls.lock().unwrap();
137 calls.insert(id, state.clone());
138 drop(calls);
139 }
140 if let Some((id, item)) = self.poll_completions() {
141 self.handle_completion(id, item);
142 }
143 self.inner
144 .write_with(|inner| {
145 inner
146 .queue
147 .submit(id, item, SubmissionFlags::NON_BLOCK)
148 .map_err(|e| e.into())
149 })
150 .await?;
151
152 let waiter = WaitPointFuture::<S, C> {
153 state,
154 sender: self,
155 };
156 let mut item = Box::pin(async { waiter.await }).fuse();
157 let mut recv = Box::pin(async {
158 loop {
159 let (id, item) = self
160 .inner
161 .read_with(|inner| {
162 inner
163 .queue
164 .get_completion(ReceiveFlags::NON_BLOCK)
165 .map_err(|e| e.into())
166 })
167 .await
168 .unwrap();
169 self.handle_completion(id, item);
170 }
171 })
172 .fuse();
173 let result = futures::select! {
174 item_res = item => item_res,
175 recv_res = recv => recv_res,
176 }?;
177 self.release_id(id);
178 Ok(result.1)
179 }
180}