twizzler_driver/request/
requester.rs1use std::{
2 collections::HashMap,
3 sync::{
4 atomic::{AtomicU32, Ordering},
5 Arc, Mutex,
6 },
7};
8
9use super::{
10 async_ids::AsyncIdAllocator,
11 inflight::{InFlight, InFlightFuture, InFlightFutureWithResponses},
12 response_info::ResponseInfo,
13 submit::{SubmitError, SubmitRequest},
14 summary::AnySubmitSummary,
15 RequestDriver,
16};
17
/// Value of `Requester::state` while the requester is accepting submissions.
const OK: u32 = 0;
/// Value of `Requester::state` after `Requester::shutdown` has been called.
const SHUTDOWN: u32 = OK + 1;
20
21pub struct Requester<T: RequestDriver> {
23 driver: T,
24 inflights: Mutex<HashMap<u64, Arc<InFlight<T::Response>>>>,
25 ids: AsyncIdAllocator,
26 state: AtomicU32,
27}
28
29impl<T: RequestDriver> Requester<T> {
30 pub fn driver(&self) -> &T {
32 &self.driver
33 }
34
35 pub fn is_shutdown(&self) -> bool {
37 self.state.load(Ordering::SeqCst) == SHUTDOWN
38 }
39
40 pub fn new(driver: T) -> Self {
42 Self {
43 ids: AsyncIdAllocator::new(T::NUM_IDS),
44 driver,
45 inflights: Mutex::new(HashMap::new()),
46 state: AtomicU32::new(OK),
47 }
48 }
49
50 async fn allocate_ids(&self, reqs: &mut [SubmitRequest<T::Request>]) -> usize {
51 for (num, req) in reqs.iter_mut().enumerate() {
52 if num == 0 {
53 req.set_id(self.ids.next().await);
54 } else {
55 if let Some(id) = self.ids.try_next() {
56 req.set_id(id);
57 } else {
58 return num;
59 }
60 }
61 }
62 reqs.len()
63 }
64
65 fn release_id(&self, id: u64) {
66 self.ids.release_id(id);
67 }
68
69 fn map_inflight(
70 &self,
71 inflight: Arc<InFlight<T::Response>>,
72 reqs: &[SubmitRequest<T::Request>],
73 idx_off: usize,
74 ) {
75 {
76 let mut map = self.inflights.lock().unwrap();
77 for req in reqs {
78 if map.insert(req.id(), inflight.clone()).is_some() {
79 panic!("tried to map existing in-flight request");
80 }
81 }
82 }
83 inflight.insert_to_map(reqs, idx_off);
84 }
85
86 async fn do_submit(
87 &self,
88 inflight: Arc<InFlight<T::Response>>,
89 reqs: &mut [SubmitRequest<T::Request>],
90 ) -> Result<(), SubmitError<T::SubmitError>> {
91 let mut idx = 0;
92 while idx < reqs.len() {
93 let count = self.allocate_ids(&mut reqs[idx..]).await;
94 self.map_inflight(inflight.clone(), &reqs[idx..(idx + count)], idx);
95 self.driver
96 .submit(&mut reqs[idx..(idx + count)])
97 .await
98 .map_err(|e| SubmitError::DriverError(e))?;
99 idx += count;
100 }
101 Ok(())
102 }
103
104 pub async fn submit(
108 &self,
109 reqs: &mut [SubmitRequest<T::Request>],
110 ) -> Result<InFlightFuture<T::Response>, SubmitError<T::SubmitError>> {
111 if self.is_shutdown() {
112 return Err(SubmitError::IsShutdown);
113 }
114 let inflight = Arc::new(InFlight::new(reqs.len(), false));
115
116 self.do_submit(inflight.clone(), reqs).await?;
117 Ok(InFlightFuture::new(inflight))
118 }
119
120 pub async fn submit_for_response(
124 &self,
125 reqs: &mut [SubmitRequest<T::Request>],
126 ) -> Result<InFlightFutureWithResponses<T::Response>, SubmitError<T::SubmitError>> {
127 if self.is_shutdown() {
128 return Err(SubmitError::IsShutdown);
129 }
130 let inflight = Arc::new(InFlight::new(reqs.len(), true));
131 self.do_submit(inflight.clone(), reqs).await?;
132 Ok(InFlightFutureWithResponses::new(inflight))
133 }
134
135 pub fn shutdown(&self) {
137 self.state.store(SHUTDOWN, Ordering::SeqCst);
138 let mut inflights = self.inflights.lock().unwrap();
139 for (_, inflight) in inflights.drain() {
140 inflight.finish(AnySubmitSummary::Shutdown);
141 }
142 }
143
144 fn take_inflight(&self, id: u64) -> Option<Arc<InFlight<T::Response>>> {
145 self.inflights.lock().unwrap().remove(&id)
146 }
147
148 pub fn finish(&self, resps: &[ResponseInfo<T::Response>]) {
152 if self.is_shutdown() {
153 return;
154 }
155 for resp in resps {
156 let inflight = self.take_inflight(resp.id());
157 if let Some(inflight) = inflight {
158 inflight.handle_resp(resp);
159 }
160
161 self.release_id(resp.id());
162 }
163 }
164}