use std::{
alloc::GlobalAlloc,
cell::UnsafeCell,
collections::HashMap,
panic::catch_unwind,
sync::{
atomic::{AtomicU32, Ordering},
Mutex,
},
};
use dynlink::tls::Tcb;
use monitor_api::TlsTemplateInfo;
use tracing::trace;
use twizzler_runtime_api::CoreRuntime;
use crate::{preinit_println, runtime::OUR_RUNTIME};
const THREAD_STARTED: u32 = 1;
pub struct RuntimeThreadControl {
internal_lock: AtomicU32,
flags: AtomicU32,
id: UnsafeCell<u32>,
}
impl Default for RuntimeThreadControl {
fn default() -> Self {
Self::new(0)
}
}
impl RuntimeThreadControl {
pub const fn new(id: u32) -> Self {
Self {
internal_lock: AtomicU32::new(0),
flags: AtomicU32::new(0),
id: UnsafeCell::new(id),
}
}
fn write_lock(&self) {
loop {
let old = self.internal_lock.fetch_or(1, Ordering::Acquire);
if old == 0 {
break;
}
}
}
fn write_unlock(&self) {
self.internal_lock.fetch_and(!1, Ordering::Release);
}
fn read_lock(&self) {
loop {
let old = self.internal_lock.fetch_add(2, Ordering::Acquire);
if old > i32::MAX as u32 {
OUR_RUNTIME.abort();
}
if old & 1 == 0 {
break;
}
}
}
fn read_unlock(&self) {
self.internal_lock.fetch_sub(2, Ordering::Release);
}
pub fn set_id(&self, id: u32) {
self.write_lock();
unsafe {
*self.id.get().as_mut().unwrap() = id;
}
self.write_unlock();
}
pub fn id(&self) -> u32 {
self.read_lock();
let id = unsafe { *self.id.get().as_ref().unwrap() };
self.read_unlock();
id
}
}
pub(super) fn with_current_thread<R, F: FnOnce(&RuntimeThreadControl) -> R>(f: F) -> R {
let tp: &mut Tcb<RuntimeThreadControl> = unsafe {
dynlink::tls::get_current_thread_control_block()
.as_mut()
.unwrap()
};
f(&tp.runtime_data)
}
pub(super) extern "C" fn trampoline(arg: usize) -> ! {
const THREAD_PANIC_CODE: u64 = 101;
let code = catch_unwind(|| {
with_current_thread(|cur| {
cur.flags.fetch_or(THREAD_STARTED, Ordering::SeqCst);
trace!("thread {} started", cur.id());
});
twizzler_abi::syscall::sys_kernel_console_write(
b"alive\n",
twizzler_abi::syscall::KernelConsoleWriteFlags::empty(),
);
let arg = unsafe {
(arg as *const twizzler_runtime_api::ThreadSpawnArgs)
.as_ref()
.unwrap()
};
let entry: extern "C" fn(usize) = unsafe { core::mem::transmute(arg.start) };
(entry)(arg.arg);
0
})
.unwrap_or(THREAD_PANIC_CODE);
twizzler_abi::syscall::sys_thread_exit(code);
}
#[derive(Default)]
pub(crate) struct TlsGenMgr {
map: HashMap<u64, TlsGen>,
}
pub(crate) struct TlsGen {
template: TlsTemplateInfo,
thread_count: usize,
}
unsafe impl Send for TlsGen {}
lazy_static::lazy_static! {
pub(crate) static ref TLS_GEN_MGR: Mutex<TlsGenMgr> = Mutex::new(TlsGenMgr::default());
}
impl TlsGenMgr {
pub fn get_next_tls_info<T>(
&mut self,
mygen: Option<u64>,
new_tcb_data: impl FnOnce() -> T,
) -> Option<*mut Tcb<T>> {
let cc = monitor_api::get_comp_config();
let template = unsafe { cc.get_tls_template().as_ref().unwrap() };
if mygen.is_some_and(|mygen| mygen == template.gen) {
return None;
}
let new = unsafe { OUR_RUNTIME.get_alloc().alloc(template.layout) };
let tlsgen = self.map.entry(template.gen).or_insert_with(|| TlsGen {
template: *template,
thread_count: 0,
});
tlsgen.thread_count += 1;
unsafe {
let tcb = tlsgen.template.init_new_tls_region(new, new_tcb_data());
Some(tcb)
}
}
}