1use std::{
5 alloc::Layout,
6 mem::{align_of, size_of},
7 ptr::NonNull,
8};
9
10use tracing::{error, trace};
11use twizzler_rt_abi::thread::TlsIndex;
12
13pub use crate::arch::Tcb;
15use crate::{
16 arch::{get_tls_variant, MINIMUM_TLS_ALIGNMENT},
17 compartment::Compartment,
18 DynlinkError, DynlinkErrorKind,
19};
20
/// Which of the two standard static-TLS layouts the target architecture uses.
///
/// * `Variant1`: module blocks are placed *above* the thread pointer (their offsets are
///   added to it), starting after a 16-byte area.
/// * `Variant2`: module blocks are placed *below* the thread pointer (their offsets are
///   subtracted from it).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TlsVariant {
    /// Blocks above the thread pointer.
    Variant1,
    /// Blocks below the thread pointer.
    Variant2,
}
26
27#[derive(Clone)]
28pub(crate) struct TlsInfo {
29 gen: u64,
31 alloc_size_mods: usize,
33 max_align: usize,
35 offset: usize,
37 pub(crate) tls_mods: Vec<TlsModule>,
38}
39
40impl TlsInfo {
41 pub(crate) fn new(gen: u64) -> Self {
42 Self {
43 gen,
44 alloc_size_mods: Default::default(),
45 max_align: MINIMUM_TLS_ALIGNMENT,
46 tls_mods: Default::default(),
47 offset: 0,
48 }
49 }
50
51 pub(crate) fn clone_to_new_gen(&self, new_gen: u64) -> Self {
52 Self {
53 gen: new_gen,
54 alloc_size_mods: self.alloc_size_mods,
55 max_align: self.max_align,
56 tls_mods: self.tls_mods.clone(),
57 offset: self.offset,
58 }
59 }
60}
61
62#[derive(Debug, Clone)]
63pub struct TlsModule {
64 pub is_static: bool,
65 pub template_addr: usize,
66 pub template_filesz: usize,
67 pub template_memsz: usize,
68 pub template_align: usize,
69 pub offset: Option<usize>,
70 pub id: Option<TlsModId>,
71}
72
73impl TlsModule {
74 pub(crate) fn new_static(
75 template_addr: usize,
76 template_filesz: usize,
77 template_memsz: usize,
78 template_align: usize,
79 ) -> Self {
80 Self {
81 is_static: true,
82 template_addr,
83 template_filesz,
84 template_memsz,
85 template_align,
86 offset: None,
87 id: None,
88 }
89 }
90}
91
92impl TlsInfo {
93 pub fn insert(&mut self, mut tm: TlsModule) -> TlsModId {
94 self.alloc_size_mods += tm.template_memsz;
96 self.max_align = std::cmp::max(self.max_align, tm.template_align);
97 self.max_align = self.max_align.next_power_of_two();
98
99 let id = match get_tls_variant() {
100 TlsVariant::Variant1 => {
101 if self.tls_mods.is_empty() {
103 self.offset = 16;
105 if !(self.offset as *const u8).is_aligned_to(tm.template_align) {
108 let ptr = self.offset as *const u8;
109 self.offset += ptr.align_offset(tm.template_align);
110 }
111 tm.offset = Some(self.offset);
112
113 self.offset += tm.template_memsz;
115 } else {
116 if !(self.offset as *const u8).is_aligned_to(tm.template_align) {
118 let ptr = self.offset as *const u8;
119 self.offset += ptr.align_offset(tm.template_align);
120 }
121
122 tm.offset = Some(self.offset);
124
125 self.offset += tm.template_memsz;
127 }
128 TlsModId((self.tls_mods.len() + 1) as u64, tm.offset.unwrap())
130 }
131 TlsVariant::Variant2 => {
132 if self.tls_mods.is_empty() {
136 self.offset = tm.template_memsz
138 + ((tm.template_addr + tm.template_memsz).overflowing_neg().0
139 & (tm.template_align - 1));
140 } else {
141 self.offset += tm.template_memsz + tm.template_align - 1;
143 self.offset -= (self.offset + tm.template_addr) & (tm.template_align - 1);
144 }
145 tm.offset = Some(self.offset);
146 TlsModId((self.tls_mods.len() + 1) as u64, self.offset)
148 }
149 };
150
151 tm.id = Some(id);
152 self.tls_mods.push(tm);
153 id
154 }
155
156 pub(crate) fn allocate<T>(
157 &self,
158 _comp: &Compartment,
159 alloc_base: NonNull<u8>,
160 tcb: T,
161 ) -> Result<TlsRegion, DynlinkError> {
162 let layout = self
164 .allocation_layout::<T>()
165 .map_err(|err| DynlinkErrorKind::LayoutError { err })?;
166
167 let thread_pointer = match get_tls_variant() {
169 TlsVariant::Variant1 => {
170 let mut base = usize::from(alloc_base.addr()) + size_of::<Tcb<T>>() - 16;
171 base += base & (layout.align() - 1);
172 NonNull::new(base as *mut u8).unwrap()
173 }
174 TlsVariant::Variant2 => {
175 let mut base = usize::from(alloc_base.addr()) + layout.size() - size_of::<Tcb<T>>();
176 base -= base & (layout.align() - 1);
178 NonNull::new(base as *mut u8).unwrap()
179 }
180 };
181
182 let module_start = match get_tls_variant() {
183 TlsVariant::Variant1 => {
184 let temp = unsafe { thread_pointer.as_ptr().add(16) };
186 let padding_after_tcb = temp.align_offset(self.tls_mods[0].template_align);
188 NonNull::new(unsafe { temp.add(padding_after_tcb) }).unwrap()
189 }
190 TlsVariant::Variant2 => thread_pointer,
191 };
192
193 let dtv_ptr = match get_tls_variant() {
195 TlsVariant::Variant1 => {
196 let after_modules = unsafe { module_start.as_ptr().add(self.offset) };
199 let align_padding = after_modules.align_offset(align_of::<usize>());
200 let dtv_addr = unsafe { after_modules.add(align_padding).cast::<usize>() };
201 NonNull::new(dtv_addr).unwrap()
202 }
203 TlsVariant::Variant2 => alloc_base.cast(),
204 };
205
206 let tls_region = TlsRegion {
207 gen: self.gen,
208 module_top: module_start,
209 thread_pointer,
210 dtv: dtv_ptr,
211 num_dtv_entries: self.dtv_len(),
212 alloc_base,
213 layout,
214 };
215
216 for tm in &self.tls_mods {
218 if !tm.is_static {
219 error!("non-static TLS modules are not supported");
220 continue;
221 }
222 tls_region.copy_in_module(tm);
223 tls_region.set_dtv_entry(tm);
224 }
225
226 trace!("setting dtv[0] to gen_count {}", self.gen);
228 unsafe { *tls_region.dtv.as_ptr() = self.gen as usize };
229
230 unsafe { (tls_region.get_thread_control_block()).write(Tcb::new(&tls_region, tcb)) };
232
233 Ok(tls_region)
234 }
235
236 fn dtv_len(&self) -> usize {
237 self.tls_mods.len() + 1
238 }
239
240 pub(crate) fn allocation_layout<T>(&self) -> Result<Layout, std::alloc::LayoutError> {
241 let align = std::cmp::max(self.max_align, align_of::<Tcb<T>>()).next_power_of_two();
243 let region_size = self.alloc_size_mods + align * (self.tls_mods.len() + 2);
247 let dtv_size = self.dtv_len() * size_of::<usize>();
248 let size = region_size + size_of::<Tcb<T>>() + dtv_size;
250 Layout::from_size_align(size, align)
251 }
252}
253
254impl<T> Tcb<T> {
255 pub(crate) fn new(tls_region: &TlsRegion, tcb_data: T) -> Self {
256 let self_ptr = unsafe { tls_region.get_thread_control_block() };
257 Self {
258 self_ptr,
259 dtv: tls_region.dtv.as_ptr(),
260 dtv_len: tls_region.num_dtv_entries,
261 runtime_data: tcb_data,
262 }
263 }
264
265 pub fn get_addr(&self, index: &TlsIndex) -> Option<*mut u8> {
266 unsafe {
267 let slice = core::slice::from_raw_parts(self.dtv, self.dtv_len);
268 Some((slice.get(index.mod_id)? + index.offset) as *mut _)
269 }
270 }
271}
272
/// Identifier handed out by `TlsInfo::insert`.
///
/// It pairs the module's 1-based ID, which is also its DTV index, with the module's offset
/// in the static TLS region.
#[derive(Debug, Clone, Copy)]
pub struct TlsModId(u64, usize);

impl TlsModId {
    /// The module ID, used as the module's index into the DTV.
    pub fn tls_id(&self) -> u64 {
        let Self(id, _) = *self;
        id
    }

    /// The module's offset within the static TLS region, as computed at insertion time.
    pub fn offset(&self) -> usize {
        let Self(_, off) = *self;
        off
    }
}
285
/// A concrete, initialized per-thread TLS region built by `TlsInfo::allocate`.
///
/// Declared `#[repr(C)]`, so the field order below is part of the layout and must not be
/// changed.
#[derive(Debug)]
#[repr(C)]
pub struct TlsRegion {
    /// Generation of the `TlsInfo` this region was built from.
    pub gen: u64,
    /// Layout of the backing memory.
    pub layout: Layout,
    /// Start of the backing memory.
    pub alloc_base: NonNull<u8>,
    /// Thread pointer value for a thread using this region.
    pub thread_pointer: NonNull<u8>,
    /// Start of the DTV; slot 0 holds the generation count.
    pub dtv: NonNull<usize>,
    /// Number of DTV slots.
    pub num_dtv_entries: usize,
    /// Reference point for module offsets.
    ///
    /// For variant 2, module blocks sit below this address. For variant 1 it marks the start
    /// of the module area.
    pub module_top: NonNull<u8>,
}
297
298impl TlsRegion {
299 pub fn alloc_base(&self) -> *mut u8 {
300 self.alloc_base.as_ptr()
301 }
302
303 pub fn alloc_layout(&self) -> Layout {
304 self.layout
305 }
306
307 pub fn get_thread_pointer_value(&self) -> usize {
308 self.thread_pointer.as_ptr() as usize
309 }
310
311 pub(crate) fn set_dtv_entry(&self, tm: &TlsModule) {
312 let dtv_slice =
313 unsafe { core::slice::from_raw_parts_mut(self.dtv.as_ptr(), self.num_dtv_entries) };
314 let dtv_idx = tm.id.as_ref().unwrap().tls_id() as usize;
315 let dtv_val = match get_tls_variant() {
316 TlsVariant::Variant1 => unsafe { self.thread_pointer.as_ptr().add(tm.offset.unwrap()) },
317 TlsVariant::Variant2 => unsafe { self.module_top.as_ptr().sub(tm.offset.unwrap()) },
318 };
319 trace!("setting dtv entry {} <= {:p}", dtv_idx, dtv_val);
320 dtv_slice[dtv_idx] = dtv_val as usize;
321 }
322
323 pub(crate) fn copy_in_module(&self, tm: &TlsModule) -> usize {
324 unsafe {
325 let start = match get_tls_variant() {
326 TlsVariant::Variant1 => self.thread_pointer.as_ptr().add(tm.offset.unwrap()),
327 TlsVariant::Variant2 => self.module_top.as_ptr().sub(tm.offset.unwrap()),
328 };
329 let src = tm.template_addr as *const u8;
330 trace!(
331 "copy in static region {:p} => {:p} (filesz={}, memsz={})",
332 src,
333 start,
334 tm.template_filesz,
335 tm.template_memsz
336 );
337 start.copy_from_nonoverlapping(src, tm.template_filesz);
338 start as usize
339 }
340 }
341}
342
343pub use crate::arch::get_current_thread_control_block;