twizzler_driver/request/
inflight.rs

1use std::{
2    collections::HashMap,
3    mem::MaybeUninit,
4    sync::{Arc, Mutex},
5    task::{Poll, Waker},
6};
7
8use super::{
9    response_info::ResponseInfo,
10    submit::SubmitRequest,
11    summary::{AnySubmitSummary, SubmitSummary, SubmitSummaryWithResponses},
12};
13
/// Mutable completion state for one batch of in-flight requests.
///
/// It is kept behind `InFlight::inner` (an `Arc<Mutex<..>>`) and is updated by
/// `InFlight::handle_resp` / `InFlight::finish` and read by the futures that await the batch.
#[derive(Debug)]
struct InFlightInner<R> {
    /// Waker of the task that most recently polled a future for this batch while it was pending.
    waker: Option<Waker>,
    /// Final summary of the batch, stored by `finish` and taken by the polling future.
    ready: Option<AnySubmitSummary<R>>,
    /// Number of responses tallied so far.
    count: usize,
    /// Index of the lowest-indexed request whose response was an error, or `usize::MAX` if no
    /// error has been seen. When responses are not collected (`resps` is `None`), any error sets
    /// this to 0.
    first_err: usize,
    /// Per-request response slots, present only when responses are being collected. A slot is
    /// uninitialized until the response for that request index has been tallied.
    resps: Option<Vec<MaybeUninit<R>>>,
    /// Request ID -> index into `resps`; only populated when `resps` is `Some`.
    map: HashMap<u64, usize>,
}
23
24impl<R> InFlightInner<R> {
25    fn new(resps: bool, len: usize) -> Self {
26        let mut s = Self {
27            waker: None,
28            ready: None,
29            count: 0,
30            first_err: usize::MAX,
31            resps: if resps {
32                Some(Vec::with_capacity(len))
33            } else {
34                None
35            },
36            map: HashMap::new(),
37        };
38        if let Some(v) = s.resps.as_mut() {
39            v.resize_with(len, || MaybeUninit::uninit());
40        }
41        s
42    }
43
44    fn finish(&mut self, val: AnySubmitSummary<R>) {
45        if self.ready.is_some() {
46            return;
47        }
48        self.ready = Some(val);
49        if let Some(w) = self.waker.take() {
50            w.wake();
51        }
52    }
53
54    fn count(&self) -> usize {
55        self.count
56    }
57
58    fn calc_summary(&mut self) -> AnySubmitSummary<R> {
59        if self.first_err == usize::MAX {
60            if let Some(resps) = self.resps.take() {
61                let arr = resps.into_raw_parts();
62                let na = unsafe { Vec::from_raw_parts(arr.0 as *mut R, arr.1, arr.2) };
63                AnySubmitSummary::Responses(na)
64            } else {
65                AnySubmitSummary::Done
66            }
67        } else {
68            if let Some(resps) = self.resps.take() {
69                let arr = resps.into_raw_parts();
70                let na = unsafe { Vec::from_raw_parts(arr.0 as *mut R, arr.1, arr.2) };
71                AnySubmitSummary::Errors(self.first_err, na)
72            } else {
73                AnySubmitSummary::Errors(self.first_err, vec![])
74            }
75        }
76    }
77
78    fn tally_resp(&mut self, resp: &ResponseInfo<R>)
79    where
80        R: Send + Copy,
81    {
82        self.count += 1;
83
84        if self.resps.is_some() {
85            let idx = *self
86                .map
87                .get(&resp.id())
88                .expect("failed to lookup ID in ID map");
89            if resp.is_err() && self.first_err > idx {
90                self.first_err = idx;
91            }
92            self.resps.as_mut().unwrap()[idx] = MaybeUninit::new(*resp.data());
93        } else {
94            if resp.is_err() {
95                self.first_err = 0;
96            }
97        }
98    }
99}
100
/// Tracks a batch of `len` submitted requests until all of their responses have been handled.
///
/// The mutable state sits behind `Arc<Mutex<..>>`. `handle_resp` updates it as responses arrive,
/// and the in-flight futures lock it when polled.
#[derive(Debug)]
pub(crate) struct InFlight<R> {
    /// Number of responses that must be tallied before the batch is complete.
    len: usize,
    /// Shared completion state for the batch.
    inner: Arc<Mutex<InFlightInner<R>>>,
}
106
107impl<R> InFlight<R> {
108    pub(crate) fn new(len: usize, resps: bool) -> Self {
109        Self {
110            len,
111            inner: Arc::new(Mutex::new(InFlightInner::new(resps, len))),
112        }
113    }
114
115    pub(crate) fn finish(&self, summ: AnySubmitSummary<R>) {
116        let mut inner = self.inner.lock().unwrap();
117        inner.finish(summ);
118    }
119
120    pub(crate) fn insert_to_map<T>(&self, reqs: &[SubmitRequest<T>], idx_off: usize) {
121        let mut inner = self.inner.lock().unwrap();
122        if inner.resps.is_some() {
123            for (idx, req) in reqs.iter().enumerate() {
124                inner.map.insert(req.id(), idx_off + idx);
125            }
126        }
127    }
128
129    pub(crate) fn handle_resp(&self, resp: &ResponseInfo<R>)
130    where
131        R: Send + Copy,
132    {
133        let mut inner = self.inner.lock().unwrap();
134        inner.tally_resp(resp);
135        if inner.count() == self.len {
136            let summ = inner.calc_summary();
137            inner.finish(summ);
138        }
139    }
140}
141
#[derive(Debug)]
/// A future for a set of in-flight requests whose individual responses from the device are not
/// of interest; we only care whether the requests completed successfully. On await, returns a
/// [SubmitSummary].
pub struct InFlightFuture<R> {
    /// Shared tracking state for the batch this future waits on.
    inflight: Arc<InFlight<R>>,
}
149
150impl<R> std::future::Future for InFlightFuture<R> {
151    type Output = SubmitSummary;
152
153    fn poll(
154        self: std::pin::Pin<&mut Self>,
155        cx: &mut std::task::Context<'_>,
156    ) -> std::task::Poll<Self::Output> {
157        let mut inner = self.inflight.inner.lock().unwrap();
158        if let Some(out) = inner.ready.take() {
159            Poll::Ready(out.into())
160        } else {
161            inner.waker = Some(cx.waker().clone());
162            Poll::Pending
163        }
164    }
165}
166
167impl<R> InFlightFuture<R> {
168    pub(crate) fn new(inflight: Arc<InFlight<R>>) -> Self {
169        Self { inflight }
170    }
171}
172
173impl<R> InFlightFutureWithResponses<R> {
174    pub(crate) fn new(inflight: Arc<InFlight<R>>) -> Self {
175        Self { inflight }
176    }
177}
178
#[derive(Debug)]
/// A future for a set of in-flight requests for which we are interested in all responses from the
/// device. On await, returns a [SubmitSummaryWithResponses].
pub struct InFlightFutureWithResponses<R> {
    /// Shared tracking state for the batch this future waits on.
    inflight: Arc<InFlight<R>>,
}
185
186impl<R> std::future::Future for InFlightFutureWithResponses<R> {
187    type Output = SubmitSummaryWithResponses<R>;
188
189    fn poll(
190        self: std::pin::Pin<&mut Self>,
191        cx: &mut std::task::Context<'_>,
192    ) -> std::task::Poll<Self::Output> {
193        let mut inner = self.inflight.inner.lock().unwrap();
194        if let Some(out) = inner.ready.take() {
195            Poll::Ready(out.into())
196        } else {
197            inner.waker = Some(cx.waker().clone());
198            Poll::Pending
199        }
200    }
201}