twizzler_driver/request/
inflight.rs1use std::{
2 collections::HashMap,
3 mem::MaybeUninit,
4 sync::{Arc, Mutex},
5 task::{Poll, Waker},
6};
7
8use super::{
9 response_info::ResponseInfo,
10 submit::SubmitRequest,
11 summary::{AnySubmitSummary, SubmitSummary, SubmitSummaryWithResponses},
12};
13
14#[derive(Debug)]
15struct InFlightInner<R> {
16 waker: Option<Waker>,
17 ready: Option<AnySubmitSummary<R>>,
18 count: usize,
19 first_err: usize,
20 resps: Option<Vec<MaybeUninit<R>>>,
21 map: HashMap<u64, usize>,
22}
23
24impl<R> InFlightInner<R> {
25 fn new(resps: bool, len: usize) -> Self {
26 let mut s = Self {
27 waker: None,
28 ready: None,
29 count: 0,
30 first_err: usize::MAX,
31 resps: if resps {
32 Some(Vec::with_capacity(len))
33 } else {
34 None
35 },
36 map: HashMap::new(),
37 };
38 if let Some(v) = s.resps.as_mut() {
39 v.resize_with(len, || MaybeUninit::uninit());
40 }
41 s
42 }
43
44 fn finish(&mut self, val: AnySubmitSummary<R>) {
45 if self.ready.is_some() {
46 return;
47 }
48 self.ready = Some(val);
49 if let Some(w) = self.waker.take() {
50 w.wake();
51 }
52 }
53
54 fn count(&self) -> usize {
55 self.count
56 }
57
58 fn calc_summary(&mut self) -> AnySubmitSummary<R> {
59 if self.first_err == usize::MAX {
60 if let Some(resps) = self.resps.take() {
61 let arr = resps.into_raw_parts();
62 let na = unsafe { Vec::from_raw_parts(arr.0 as *mut R, arr.1, arr.2) };
63 AnySubmitSummary::Responses(na)
64 } else {
65 AnySubmitSummary::Done
66 }
67 } else {
68 if let Some(resps) = self.resps.take() {
69 let arr = resps.into_raw_parts();
70 let na = unsafe { Vec::from_raw_parts(arr.0 as *mut R, arr.1, arr.2) };
71 AnySubmitSummary::Errors(self.first_err, na)
72 } else {
73 AnySubmitSummary::Errors(self.first_err, vec![])
74 }
75 }
76 }
77
78 fn tally_resp(&mut self, resp: &ResponseInfo<R>)
79 where
80 R: Send + Copy,
81 {
82 self.count += 1;
83
84 if self.resps.is_some() {
85 let idx = *self
86 .map
87 .get(&resp.id())
88 .expect("failed to lookup ID in ID map");
89 if resp.is_err() && self.first_err > idx {
90 self.first_err = idx;
91 }
92 self.resps.as_mut().unwrap()[idx] = MaybeUninit::new(*resp.data());
93 } else {
94 if resp.is_err() {
95 self.first_err = 0;
96 }
97 }
98 }
99}
100
101#[derive(Debug)]
102pub(crate) struct InFlight<R> {
103 len: usize,
104 inner: Arc<Mutex<InFlightInner<R>>>,
105}
106
107impl<R> InFlight<R> {
108 pub(crate) fn new(len: usize, resps: bool) -> Self {
109 Self {
110 len,
111 inner: Arc::new(Mutex::new(InFlightInner::new(resps, len))),
112 }
113 }
114
115 pub(crate) fn finish(&self, summ: AnySubmitSummary<R>) {
116 let mut inner = self.inner.lock().unwrap();
117 inner.finish(summ);
118 }
119
120 pub(crate) fn insert_to_map<T>(&self, reqs: &[SubmitRequest<T>], idx_off: usize) {
121 let mut inner = self.inner.lock().unwrap();
122 if inner.resps.is_some() {
123 for (idx, req) in reqs.iter().enumerate() {
124 inner.map.insert(req.id(), idx_off + idx);
125 }
126 }
127 }
128
129 pub(crate) fn handle_resp(&self, resp: &ResponseInfo<R>)
130 where
131 R: Send + Copy,
132 {
133 let mut inner = self.inner.lock().unwrap();
134 inner.tally_resp(resp);
135 if inner.count() == self.len {
136 let summ = inner.calc_summary();
137 inner.finish(summ);
138 }
139 }
140}
141
142#[derive(Debug)]
143pub struct InFlightFuture<R> {
147 inflight: Arc<InFlight<R>>,
148}
149
150impl<R> std::future::Future for InFlightFuture<R> {
151 type Output = SubmitSummary;
152
153 fn poll(
154 self: std::pin::Pin<&mut Self>,
155 cx: &mut std::task::Context<'_>,
156 ) -> std::task::Poll<Self::Output> {
157 let mut inner = self.inflight.inner.lock().unwrap();
158 if let Some(out) = inner.ready.take() {
159 Poll::Ready(out.into())
160 } else {
161 inner.waker = Some(cx.waker().clone());
162 Poll::Pending
163 }
164 }
165}
166
167impl<R> InFlightFuture<R> {
168 pub(crate) fn new(inflight: Arc<InFlight<R>>) -> Self {
169 Self { inflight }
170 }
171}
172
173impl<R> InFlightFutureWithResponses<R> {
174 pub(crate) fn new(inflight: Arc<InFlight<R>>) -> Self {
175 Self { inflight }
176 }
177}
178
179#[derive(Debug)]
180pub struct InFlightFutureWithResponses<R> {
183 inflight: Arc<InFlight<R>>,
184}
185
186impl<R> std::future::Future for InFlightFutureWithResponses<R> {
187 type Output = SubmitSummaryWithResponses<R>;
188
189 fn poll(
190 self: std::pin::Pin<&mut Self>,
191 cx: &mut std::task::Context<'_>,
192 ) -> std::task::Poll<Self::Output> {
193 let mut inner = self.inflight.inner.lock().unwrap();
194 if let Some(out) = inner.ready.take() {
195 Poll::Ready(out.into())
196 } else {
197 inner.waker = Some(cx.waker().clone());
198 Poll::Pending
199 }
200 }
201}