rusqlite/
functions.rs

1//! Create or redefine SQL functions.
2//!
3//! # Example
4//!
//! Adding a `regexp` function to a connection. The compiled regular
//! expression is cached with SQLite's [Function Auxiliary Data](https://www.sqlite.org/c3ref/get_auxdata.html)
//! interface (see [`Context::get_or_create_aux`]), so that SQLite can reuse it
//! instead of recompiling a constant pattern for every row. The unit tests for
//! this module contain a similar implementation.
10//!
11//! ```rust
12//! use regex::Regex;
13//! use rusqlite::functions::FunctionFlags;
14//! use rusqlite::{Connection, Error, Result};
15//! use std::sync::Arc;
16//! type BoxError = Box<dyn std::error::Error + Send + Sync + 'static>;
17//!
18//! fn add_regexp_function(db: &Connection) -> Result<()> {
19//!     db.create_scalar_function(
20//!         "regexp",
21//!         2,
22//!         FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
23//!         move |ctx| {
24//!             assert_eq!(ctx.len(), 2, "called with unexpected number of arguments");
25//!             let regexp: Arc<Regex> = ctx.get_or_create_aux(0, |vr| -> Result<_, BoxError> {
26//!                 Ok(Regex::new(vr.as_str()?)?)
27//!             })?;
28//!             let is_match = {
29//!                 let text = ctx
30//!                     .get_raw(1)
31//!                     .as_str()
32//!                     .map_err(|e| Error::UserFunctionError(e.into()))?;
33//!
34//!                 regexp.is_match(text)
35//!             };
36//!
37//!             Ok(is_match)
38//!         },
39//!     )
40//! }
41//!
42//! fn main() -> Result<()> {
43//!     let db = Connection::open_in_memory()?;
44//!     add_regexp_function(&db)?;
45//!
46//!     let is_match: bool =
47//!         db.query_row("SELECT regexp('[aeiou]*', 'aaaaeeeiii')", [], |row| {
48//!             row.get(0)
49//!         })?;
50//!
51//!     assert!(is_match);
52//!     Ok(())
53//! }
54//! ```
55use std::any::Any;
56use std::marker::PhantomData;
57use std::ops::Deref;
58use std::os::raw::{c_int, c_void};
59use std::panic::{catch_unwind, RefUnwindSafe, UnwindSafe};
60use std::ptr;
61use std::slice;
62use std::sync::Arc;
63
64use crate::ffi;
65use crate::ffi::sqlite3_context;
66use crate::ffi::sqlite3_value;
67
68use crate::context::set_result;
69use crate::types::{FromSql, FromSqlError, ToSql, ToSqlOutput, ValueRef};
70
71use crate::{str_to_cstring, Connection, Error, InnerConnection, Result};
72
73unsafe fn report_error(ctx: *mut sqlite3_context, err: &Error) {
74    if let Error::SqliteFailure(ref err, ref s) = *err {
75        ffi::sqlite3_result_error_code(ctx, err.extended_code);
76        if let Some(Ok(cstr)) = s.as_ref().map(|s| str_to_cstring(s)) {
77            ffi::sqlite3_result_error(ctx, cstr.as_ptr(), -1);
78        }
79    } else {
80        ffi::sqlite3_result_error_code(ctx, ffi::SQLITE_CONSTRAINT_FUNCTION);
81        if let Ok(cstr) = str_to_cstring(&err.to_string()) {
82            ffi::sqlite3_result_error(ctx, cstr.as_ptr(), -1);
83        }
84    }
85}
86
/// C-ABI destructor that SQLite calls for values leaked with `Box::into_raw`.
///
/// Rebuilds the `Box<T>` and drops it, releasing the value.
///
/// # Safety
///
/// `p` must have been produced by `Box::<T>::into_raw` and must not be used
/// afterwards.
unsafe extern "C" fn free_boxed_value<T>(p: *mut c_void) {
    let owned: Box<T> = Box::from_raw(p.cast());
    std::mem::drop(owned);
}
90
91/// Context is a wrapper for the SQLite function
92/// evaluation context.
93pub struct Context<'a> {
94    ctx: *mut sqlite3_context,
95    args: &'a [*mut sqlite3_value],
96}
97
98impl Context<'_> {
99    /// Returns the number of arguments to the function.
100    #[inline]
101    #[must_use]
102    pub fn len(&self) -> usize {
103        self.args.len()
104    }
105
106    /// Returns `true` when there is no argument.
107    #[inline]
108    #[must_use]
109    pub fn is_empty(&self) -> bool {
110        self.args.is_empty()
111    }
112
113    /// Returns the `idx`th argument as a `T`.
114    ///
115    /// # Failure
116    ///
117    /// Will panic if `idx` is greater than or equal to
118    /// [`self.len()`](Context::len).
119    ///
120    /// Will return Err if the underlying SQLite type cannot be converted to a
121    /// `T`.
122    pub fn get<T: FromSql>(&self, idx: usize) -> Result<T> {
123        let arg = self.args[idx];
124        let value = unsafe { ValueRef::from_value(arg) };
125        FromSql::column_result(value).map_err(|err| match err {
126            FromSqlError::InvalidType => {
127                Error::InvalidFunctionParameterType(idx, value.data_type())
128            }
129            FromSqlError::OutOfRange(i) => Error::IntegralValueOutOfRange(idx, i),
130            FromSqlError::Other(err) => {
131                Error::FromSqlConversionFailure(idx, value.data_type(), err)
132            }
133            FromSqlError::InvalidBlobSize { .. } => {
134                Error::FromSqlConversionFailure(idx, value.data_type(), Box::new(err))
135            }
136        })
137    }
138
139    /// Returns the `idx`th argument as a `ValueRef`.
140    ///
141    /// # Failure
142    ///
143    /// Will panic if `idx` is greater than or equal to
144    /// [`self.len()`](Context::len).
145    #[inline]
146    #[must_use]
147    pub fn get_raw(&self, idx: usize) -> ValueRef<'_> {
148        let arg = self.args[idx];
149        unsafe { ValueRef::from_value(arg) }
150    }
151
152    /// Returns the `idx`th argument as a `SqlFnArg`.
153    /// To be used when the SQL function result is one of its arguments.
154    #[inline]
155    #[must_use]
156    pub fn get_arg(&self, idx: usize) -> SqlFnArg {
157        assert!(idx < self.len());
158        SqlFnArg { idx }
159    }
160
161    /// Returns the subtype of `idx`th argument.
162    ///
163    /// # Failure
164    ///
165    /// Will panic if `idx` is greater than or equal to
166    /// [`self.len()`](Context::len).
167    pub fn get_subtype(&self, idx: usize) -> std::os::raw::c_uint {
168        let arg = self.args[idx];
169        unsafe { ffi::sqlite3_value_subtype(arg) }
170    }
171
172    /// Fetch or insert the auxiliary data associated with a particular
173    /// parameter. This is intended to be an easier-to-use way of fetching it
174    /// compared to calling [`get_aux`](Context::get_aux) and
175    /// [`set_aux`](Context::set_aux) separately.
176    ///
177    /// See `https://www.sqlite.org/c3ref/get_auxdata.html` for a discussion of
178    /// this feature, or the unit tests of this module for an example.
179    ///
180    /// # Failure
181    ///
182    /// Will panic if `arg` is greater than or equal to
183    /// [`self.len()`](Context::len).
184    pub fn get_or_create_aux<T, E, F>(&self, arg: c_int, func: F) -> Result<Arc<T>>
185    where
186        T: Send + Sync + 'static,
187        E: Into<Box<dyn std::error::Error + Send + Sync + 'static>>,
188        F: FnOnce(ValueRef<'_>) -> Result<T, E>,
189    {
190        if let Some(v) = self.get_aux(arg)? {
191            Ok(v)
192        } else {
193            let vr = self.get_raw(arg as usize);
194            self.set_aux(
195                arg,
196                func(vr).map_err(|e| Error::UserFunctionError(e.into()))?,
197            )
198        }
199    }
200
201    /// Sets the auxiliary data associated with a particular parameter. See
202    /// `https://www.sqlite.org/c3ref/get_auxdata.html` for a discussion of
203    /// this feature, or the unit tests of this module for an example.
204    ///
205    /// # Failure
206    ///
207    /// Will panic if `arg` is greater than or equal to
208    /// [`self.len()`](Context::len).
209    pub fn set_aux<T: Send + Sync + 'static>(&self, arg: c_int, value: T) -> Result<Arc<T>> {
210        assert!(arg < self.len() as i32);
211        let orig: Arc<T> = Arc::new(value);
212        let inner: AuxInner = orig.clone();
213        let outer = Box::new(inner);
214        let raw: *mut AuxInner = Box::into_raw(outer);
215        unsafe {
216            ffi::sqlite3_set_auxdata(
217                self.ctx,
218                arg,
219                raw.cast(),
220                Some(free_boxed_value::<AuxInner>),
221            );
222        };
223        Ok(orig)
224    }
225
226    /// Gets the auxiliary data that was associated with a given parameter via
227    /// [`set_aux`](Context::set_aux). Returns `Ok(None)` if no data has been
228    /// associated, and Ok(Some(v)) if it has. Returns an error if the
229    /// requested type does not match.
230    ///
231    /// # Failure
232    ///
233    /// Will panic if `arg` is greater than or equal to
234    /// [`self.len()`](Context::len).
235    pub fn get_aux<T: Send + Sync + 'static>(&self, arg: c_int) -> Result<Option<Arc<T>>> {
236        assert!(arg < self.len() as i32);
237        let p = unsafe { ffi::sqlite3_get_auxdata(self.ctx, arg) as *const AuxInner };
238        if p.is_null() {
239            Ok(None)
240        } else {
241            let v: AuxInner = AuxInner::clone(unsafe { &*p });
242            v.downcast::<T>()
243                .map(Some)
244                .map_err(|_| Error::GetAuxWrongType)
245        }
246    }
247
248    /// Get the db connection handle via [sqlite3_context_db_handle](https://www.sqlite.org/c3ref/context_db_handle.html)
249    ///
250    /// # Safety
251    ///
252    /// This function is marked unsafe because there is a potential for other
253    /// references to the connection to be sent across threads, [see this comment](https://github.com/rusqlite/rusqlite/issues/643#issuecomment-640181213).
254    pub unsafe fn get_connection(&self) -> Result<ConnectionRef<'_>> {
255        let handle = ffi::sqlite3_context_db_handle(self.ctx);
256        Ok(ConnectionRef {
257            conn: Connection::from_handle(handle)?,
258            phantom: PhantomData,
259        })
260    }
261}
262
263/// A reference to a connection handle with a lifetime bound to something.
264pub struct ConnectionRef<'ctx> {
265    // comes from Connection::from_handle(sqlite3_context_db_handle(...))
266    // and is non-owning
267    conn: Connection,
268    phantom: PhantomData<&'ctx Context<'ctx>>,
269}
270
271impl Deref for ConnectionRef<'_> {
272    type Target = Connection;
273
274    #[inline]
275    fn deref(&self) -> &Connection {
276        &self.conn
277    }
278}
279
280type AuxInner = Arc<dyn Any + Send + Sync + 'static>;
281
282/// Subtype of an SQL function
283pub type SubType = Option<std::os::raw::c_uint>;
284
285/// Result of an SQL function
286pub trait SqlFnOutput {
287    /// Converts Rust value to SQLite value with an optional subtype
288    fn to_sql(&self) -> Result<(ToSqlOutput<'_>, SubType)>;
289}
290
291impl<T: ToSql> SqlFnOutput for T {
292    #[inline]
293    fn to_sql(&self) -> Result<(ToSqlOutput<'_>, SubType)> {
294        ToSql::to_sql(self).map(|o| (o, None))
295    }
296}
297
298impl<T: ToSql> SqlFnOutput for (T, SubType) {
299    fn to_sql(&self) -> Result<(ToSqlOutput<'_>, SubType)> {
300        ToSql::to_sql(&self.0).map(|o| (o, self.1))
301    }
302}
303
304/// n-th arg of an SQL scalar function
305pub struct SqlFnArg {
306    idx: usize,
307}
308impl ToSql for SqlFnArg {
309    fn to_sql(&self) -> Result<ToSqlOutput<'_>> {
310        Ok(ToSqlOutput::Arg(self.idx))
311    }
312}
313
314unsafe fn sql_result<T: SqlFnOutput>(
315    ctx: *mut sqlite3_context,
316    args: &[*mut sqlite3_value],
317    r: Result<T>,
318) {
319    let t = r.as_ref().map(SqlFnOutput::to_sql);
320
321    match t {
322        Ok(Ok((ref value, sub_type))) => {
323            set_result(ctx, args, value);
324            if let Some(sub_type) = sub_type {
325                ffi::sqlite3_result_subtype(ctx, sub_type);
326            }
327        }
328        Ok(Err(err)) => report_error(ctx, &err),
329        Err(err) => report_error(ctx, err),
330    };
331}
332
333/// Aggregate is the callback interface for user-defined
334/// aggregate function.
335///
336/// `A` is the type of the aggregation context and `T` is the type of the final
337/// result. Implementations should be stateless.
338pub trait Aggregate<A, T>
339where
340    A: RefUnwindSafe + UnwindSafe,
341    T: SqlFnOutput,
342{
343    /// Initializes the aggregation context. Will be called prior to the first
344    /// call to [`step()`](Aggregate::step) to set up the context for an
345    /// invocation of the function. (Note: `init()` will not be called if
346    /// there are no rows.)
347    fn init(&self, ctx: &mut Context<'_>) -> Result<A>;
348
349    /// "step" function called once for each row in an aggregate group. May be
350    /// called 0 times if there are no rows.
351    fn step(&self, ctx: &mut Context<'_>, acc: &mut A) -> Result<()>;
352
353    /// Computes and returns the final result. Will be called exactly once for
354    /// each invocation of the function. If [`step()`](Aggregate::step) was
355    /// called at least once, will be given `Some(A)` (the same `A` as was
356    /// created by [`init`](Aggregate::init) and given to
357    /// [`step`](Aggregate::step)); if [`step()`](Aggregate::step) was not
358    /// called (because the function is running against 0 rows), will be
359    /// given `None`.
360    ///
361    /// The passed context will have no arguments.
362    fn finalize(&self, ctx: &mut Context<'_>, acc: Option<A>) -> Result<T>;
363}
364
/// Callback interface for a user-defined aggregate *window* function.
///
/// Extends [`Aggregate`] with the callbacks SQLite needs to evaluate the
/// function over a sliding window frame.
#[cfg(feature = "window")]
#[cfg_attr(docsrs, doc(cfg(feature = "window")))]
pub trait WindowAggregate<A: RefUnwindSafe + UnwindSafe, T: SqlFnOutput>: Aggregate<A, T> {
    /// Returns the current value of the aggregate without consuming the
    /// accumulator. Unlike xFinal, the implementation should not delete any
    /// context.
    ///
    /// `acc` is `None` when no row has been added yet.
    fn value(&self, acc: Option<&mut A>) -> Result<T>;

    /// Removes a row (given by the arguments in `ctx`) from the current
    /// window.
    fn inverse(&self, ctx: &mut Context<'_>, acc: &mut A) -> Result<()>;
}
381
382bitflags::bitflags! {
383    /// Function Flags.
384    /// See [sqlite3_create_function](https://sqlite.org/c3ref/create_function.html)
385    /// and [Function Flags](https://sqlite.org/c3ref/c_deterministic.html) for details.
386    #[derive(Clone, Copy, Debug)]
387    #[repr(C)]
388    pub struct FunctionFlags: c_int {
389        /// Specifies UTF-8 as the text encoding this SQL function prefers for its parameters.
390        const SQLITE_UTF8     = ffi::SQLITE_UTF8;
391        /// Specifies UTF-16 using little-endian byte order as the text encoding this SQL function prefers for its parameters.
392        const SQLITE_UTF16LE  = ffi::SQLITE_UTF16LE;
393        /// Specifies UTF-16 using big-endian byte order as the text encoding this SQL function prefers for its parameters.
394        const SQLITE_UTF16BE  = ffi::SQLITE_UTF16BE;
395        /// Specifies UTF-16 using native byte order as the text encoding this SQL function prefers for its parameters.
396        const SQLITE_UTF16    = ffi::SQLITE_UTF16;
397        /// Means that the function always gives the same output when the input parameters are the same.
398        const SQLITE_DETERMINISTIC = ffi::SQLITE_DETERMINISTIC; // 3.8.3
399        /// Means that the function may only be invoked from top-level SQL.
400        const SQLITE_DIRECTONLY    = 0x0000_0008_0000; // 3.30.0
401        /// Indicates to SQLite that a function may call `sqlite3_value_subtype()` to inspect the subtypes of its arguments.
402        const SQLITE_SUBTYPE       = 0x0000_0010_0000; // 3.30.0
403        /// Means that the function is unlikely to cause problems even if misused.
404        const SQLITE_INNOCUOUS     = 0x0000_0020_0000; // 3.31.0
405        /// Indicates to SQLite that a function might call `sqlite3_result_subtype()` to cause a subtype to be associated with its result.
406        const SQLITE_RESULT_SUBTYPE     = 0x0000_0100_0000; // 3.45.0
407        /// Indicates that the function is an aggregate that internally orders the values provided to the first argument.
408        const SQLITE_SELFORDER1 = 0x0000_0200_0000; // 3.47.0
409    }
410}
411
412impl Default for FunctionFlags {
413    #[inline]
414    fn default() -> Self {
415        Self::SQLITE_UTF8
416    }
417}
418
419impl Connection {
420    /// Attach a user-defined scalar function to
421    /// this database connection.
422    ///
423    /// `fn_name` is the name the function will be accessible from SQL.
424    /// `n_arg` is the number of arguments to the function. Use `-1` for a
425    /// variable number. If the function always returns the same value
426    /// given the same input, `deterministic` should be `true`.
427    ///
428    /// The function will remain available until the connection is closed or
429    /// until it is explicitly removed via
430    /// [`remove_function`](Connection::remove_function).
431    ///
432    /// # Example
433    ///
434    /// ```rust
435    /// # use rusqlite::{Connection, Result};
436    /// # use rusqlite::functions::FunctionFlags;
437    /// fn scalar_function_example(db: Connection) -> Result<()> {
438    ///     db.create_scalar_function(
439    ///         "halve",
440    ///         1,
441    ///         FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
442    ///         |ctx| {
443    ///             let value = ctx.get::<f64>(0)?;
444    ///             Ok(value / 2f64)
445    ///         },
446    ///     )?;
447    ///
448    ///     let six_halved: f64 = db.query_row("SELECT halve(6)", [], |r| r.get(0))?;
449    ///     assert_eq!(six_halved, 3f64);
450    ///     Ok(())
451    /// }
452    /// ```
453    ///
454    /// # Failure
455    ///
456    /// Will return Err if the function could not be attached to the connection.
457    #[inline]
458    pub fn create_scalar_function<F, T>(
459        &self,
460        fn_name: &str,
461        n_arg: c_int,
462        flags: FunctionFlags,
463        x_func: F,
464    ) -> Result<()>
465    where
466        F: Fn(&Context<'_>) -> Result<T> + Send + 'static,
467        T: SqlFnOutput,
468    {
469        self.db
470            .borrow_mut()
471            .create_scalar_function(fn_name, n_arg, flags, x_func)
472    }
473
474    /// Attach a user-defined aggregate function to this
475    /// database connection.
476    ///
477    /// # Failure
478    ///
479    /// Will return Err if the function could not be attached to the connection.
480    #[inline]
481    pub fn create_aggregate_function<A, D, T>(
482        &self,
483        fn_name: &str,
484        n_arg: c_int,
485        flags: FunctionFlags,
486        aggr: D,
487    ) -> Result<()>
488    where
489        A: RefUnwindSafe + UnwindSafe,
490        D: Aggregate<A, T> + 'static,
491        T: SqlFnOutput,
492    {
493        self.db
494            .borrow_mut()
495            .create_aggregate_function(fn_name, n_arg, flags, aggr)
496    }
497
498    /// Attach a user-defined aggregate window function to
499    /// this database connection.
500    ///
501    /// See `https://sqlite.org/windowfunctions.html#udfwinfunc` for more
502    /// information.
503    #[cfg(feature = "window")]
504    #[cfg_attr(docsrs, doc(cfg(feature = "window")))]
505    #[inline]
506    pub fn create_window_function<A, W, T>(
507        &self,
508        fn_name: &str,
509        n_arg: c_int,
510        flags: FunctionFlags,
511        aggr: W,
512    ) -> Result<()>
513    where
514        A: RefUnwindSafe + UnwindSafe,
515        W: WindowAggregate<A, T> + 'static,
516        T: SqlFnOutput,
517    {
518        self.db
519            .borrow_mut()
520            .create_window_function(fn_name, n_arg, flags, aggr)
521    }
522
523    /// Removes a user-defined function from this
524    /// database connection.
525    ///
526    /// `fn_name` and `n_arg` should match the name and number of arguments
527    /// given to [`create_scalar_function`](Connection::create_scalar_function)
528    /// or [`create_aggregate_function`](Connection::create_aggregate_function).
529    ///
530    /// # Failure
531    ///
532    /// Will return Err if the function could not be removed.
533    #[inline]
534    pub fn remove_function(&self, fn_name: &str, n_arg: c_int) -> Result<()> {
535        self.db.borrow_mut().remove_function(fn_name, n_arg)
536    }
537}
538
539impl InnerConnection {
540    /// ```compile_fail
541    /// use rusqlite::{functions::FunctionFlags, Connection, Result};
542    /// fn main() -> Result<()> {
543    ///     let db = Connection::open_in_memory()?;
544    ///     {
545    ///         let mut called = std::sync::atomic::AtomicBool::new(false);
546    ///         db.create_scalar_function(
547    ///             "test",
548    ///             0,
549    ///             FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
550    ///             |_| {
551    ///                 called.store(true, std::sync::atomic::Ordering::Relaxed);
552    ///                 Ok(true)
553    ///             },
554    ///         );
555    ///     }
556    ///     let result: Result<bool> = db.query_row("SELECT test()", [], |r| r.get(0));
557    ///     assert!(result?);
558    ///     Ok(())
559    /// }
560    /// ```
561    fn create_scalar_function<F, T>(
562        &mut self,
563        fn_name: &str,
564        n_arg: c_int,
565        flags: FunctionFlags,
566        x_func: F,
567    ) -> Result<()>
568    where
569        F: Fn(&Context<'_>) -> Result<T> + Send + 'static,
570        T: SqlFnOutput,
571    {
572        unsafe extern "C" fn call_boxed_closure<F, T>(
573            ctx: *mut sqlite3_context,
574            argc: c_int,
575            argv: *mut *mut sqlite3_value,
576        ) where
577            F: Fn(&Context<'_>) -> Result<T>,
578            T: SqlFnOutput,
579        {
580            let args = slice::from_raw_parts(argv, argc as usize);
581            let r = catch_unwind(|| {
582                let boxed_f: *const F = ffi::sqlite3_user_data(ctx).cast::<F>();
583                assert!(!boxed_f.is_null(), "Internal error - null function pointer");
584                let ctx = Context { ctx, args };
585                (*boxed_f)(&ctx)
586            });
587            let t = match r {
588                Err(_) => {
589                    report_error(ctx, &Error::UnwindingPanic);
590                    return;
591                }
592                Ok(r) => r,
593            };
594            sql_result(ctx, args, t);
595        }
596
597        let boxed_f: *mut F = Box::into_raw(Box::new(x_func));
598        let c_name = str_to_cstring(fn_name)?;
599        let r = unsafe {
600            ffi::sqlite3_create_function_v2(
601                self.db(),
602                c_name.as_ptr(),
603                n_arg,
604                flags.bits(),
605                boxed_f.cast::<c_void>(),
606                Some(call_boxed_closure::<F, T>),
607                None,
608                None,
609                Some(free_boxed_value::<F>),
610            )
611        };
612        self.decode_result(r)
613    }
614
615    fn create_aggregate_function<A, D, T>(
616        &mut self,
617        fn_name: &str,
618        n_arg: c_int,
619        flags: FunctionFlags,
620        aggr: D,
621    ) -> Result<()>
622    where
623        A: RefUnwindSafe + UnwindSafe,
624        D: Aggregate<A, T> + 'static,
625        T: SqlFnOutput,
626    {
627        let boxed_aggr: *mut D = Box::into_raw(Box::new(aggr));
628        let c_name = str_to_cstring(fn_name)?;
629        let r = unsafe {
630            ffi::sqlite3_create_function_v2(
631                self.db(),
632                c_name.as_ptr(),
633                n_arg,
634                flags.bits(),
635                boxed_aggr.cast::<c_void>(),
636                None,
637                Some(call_boxed_step::<A, D, T>),
638                Some(call_boxed_final::<A, D, T>),
639                Some(free_boxed_value::<D>),
640            )
641        };
642        self.decode_result(r)
643    }
644
645    #[cfg(feature = "window")]
646    fn create_window_function<A, W, T>(
647        &mut self,
648        fn_name: &str,
649        n_arg: c_int,
650        flags: FunctionFlags,
651        aggr: W,
652    ) -> Result<()>
653    where
654        A: RefUnwindSafe + UnwindSafe,
655        W: WindowAggregate<A, T> + 'static,
656        T: SqlFnOutput,
657    {
658        let boxed_aggr: *mut W = Box::into_raw(Box::new(aggr));
659        let c_name = str_to_cstring(fn_name)?;
660        let r = unsafe {
661            ffi::sqlite3_create_window_function(
662                self.db(),
663                c_name.as_ptr(),
664                n_arg,
665                flags.bits(),
666                boxed_aggr.cast::<c_void>(),
667                Some(call_boxed_step::<A, W, T>),
668                Some(call_boxed_final::<A, W, T>),
669                Some(call_boxed_value::<A, W, T>),
670                Some(call_boxed_inverse::<A, W, T>),
671                Some(free_boxed_value::<W>),
672            )
673        };
674        self.decode_result(r)
675    }
676
677    fn remove_function(&mut self, fn_name: &str, n_arg: c_int) -> Result<()> {
678        let c_name = str_to_cstring(fn_name)?;
679        let r = unsafe {
680            ffi::sqlite3_create_function_v2(
681                self.db(),
682                c_name.as_ptr(),
683                n_arg,
684                ffi::SQLITE_UTF8,
685                ptr::null_mut(),
686                None,
687                None,
688                None,
689                None,
690            )
691        };
692        self.decode_result(r)
693    }
694}
695
696unsafe fn aggregate_context<A>(ctx: *mut sqlite3_context, bytes: usize) -> Option<*mut *mut A> {
697    let pac = ffi::sqlite3_aggregate_context(ctx, bytes as c_int) as *mut *mut A;
698    if pac.is_null() {
699        return None;
700    }
701    Some(pac)
702}
703
704unsafe extern "C" fn call_boxed_step<A, D, T>(
705    ctx: *mut sqlite3_context,
706    argc: c_int,
707    argv: *mut *mut sqlite3_value,
708) where
709    A: RefUnwindSafe + UnwindSafe,
710    D: Aggregate<A, T>,
711    T: SqlFnOutput,
712{
713    let Some(pac) = aggregate_context(ctx, size_of::<*mut A>()) else {
714        ffi::sqlite3_result_error_nomem(ctx);
715        return;
716    };
717
718    let r = catch_unwind(|| {
719        let boxed_aggr: *mut D = ffi::sqlite3_user_data(ctx).cast::<D>();
720        assert!(
721            !boxed_aggr.is_null(),
722            "Internal error - null aggregate pointer"
723        );
724        let mut ctx = Context {
725            ctx,
726            args: slice::from_raw_parts(argv, argc as usize),
727        };
728
729        #[expect(clippy::unnecessary_cast)]
730        if (*pac as *mut A).is_null() {
731            *pac = Box::into_raw(Box::new((*boxed_aggr).init(&mut ctx)?));
732        }
733
734        (*boxed_aggr).step(&mut ctx, &mut **pac)
735    });
736    let r = match r {
737        Err(_) => {
738            report_error(ctx, &Error::UnwindingPanic);
739            return;
740        }
741        Ok(r) => r,
742    };
743    match r {
744        Ok(_) => {}
745        Err(err) => report_error(ctx, &err),
746    };
747}
748
/// `xInverse` trampoline for window functions whose user data is a boxed
/// `W`. It forwards the row being removed to `WindowAggregate::inverse`.
/// A panic is reported as `Error::UnwindingPanic`; an error is reported
/// through `report_error`.
#[cfg(feature = "window")]
unsafe extern "C" fn call_boxed_inverse<A, W, T>(
    ctx: *mut sqlite3_context,
    argc: c_int,
    argv: *mut *mut sqlite3_value,
) where
    A: RefUnwindSafe + UnwindSafe,
    W: WindowAggregate<A, T>,
    T: SqlFnOutput,
{
    let pac = match aggregate_context(ctx, size_of::<*mut A>()) {
        Some(pac) => pac,
        None => {
            ffi::sqlite3_result_error_nomem(ctx);
            return;
        }
    };

    let outcome = catch_unwind(|| {
        let aggr_ptr: *mut W = ffi::sqlite3_user_data(ctx).cast::<W>();
        assert!(
            !aggr_ptr.is_null(),
            "Internal error - null aggregate pointer"
        );
        let mut fn_ctx = Context {
            ctx,
            args: slice::from_raw_parts(argv, argc as usize),
        };
        (*aggr_ptr).inverse(&mut fn_ctx, &mut **pac)
    });

    match outcome {
        Err(_) => report_error(ctx, &Error::UnwindingPanic),
        Ok(Err(err)) => report_error(ctx, &err),
        Ok(Ok(())) => {}
    }
}
788
789unsafe extern "C" fn call_boxed_final<A, D, T>(ctx: *mut sqlite3_context)
790where
791    A: RefUnwindSafe + UnwindSafe,
792    D: Aggregate<A, T>,
793    T: SqlFnOutput,
794{
795    // Within the xFinal callback, it is customary to set N=0 in calls to
796    // sqlite3_aggregate_context(C,N) so that no pointless memory allocations occur.
797    let a: Option<A> = match aggregate_context(ctx, 0) {
798        Some(pac) =>
799        {
800            #[expect(clippy::unnecessary_cast)]
801            if (*pac as *mut A).is_null() {
802                None
803            } else {
804                let a = Box::from_raw(*pac);
805                Some(*a)
806            }
807        }
808        None => None,
809    };
810
811    let r = catch_unwind(|| {
812        let boxed_aggr: *mut D = ffi::sqlite3_user_data(ctx).cast::<D>();
813        assert!(
814            !boxed_aggr.is_null(),
815            "Internal error - null aggregate pointer"
816        );
817        let mut ctx = Context { ctx, args: &mut [] };
818        (*boxed_aggr).finalize(&mut ctx, a)
819    });
820    let t = match r {
821        Err(_) => {
822            report_error(ctx, &Error::UnwindingPanic);
823            return;
824        }
825        Ok(r) => r,
826    };
827    sql_result(ctx, &[], t);
828}
829
/// `xValue` trampoline for window functions whose user data is a boxed `W`.
///
/// Lends the current accumulator (if any) to `WindowAggregate::value`
/// without consuming it, and publishes the result with `sql_result`.
/// A panic is reported as `Error::UnwindingPanic`.
#[cfg(feature = "window")]
unsafe extern "C" fn call_boxed_value<A, W, T>(ctx: *mut sqlite3_context)
where
    A: RefUnwindSafe + UnwindSafe,
    W: WindowAggregate<A, T>,
    T: SqlFnOutput,
{
    // Within xValue it is customary to pass N=0 to sqlite3_aggregate_context
    // so that no pointless memory allocation occurs.
    let pac: Option<*mut *mut A> = match aggregate_context::<A>(ctx, 0) {
        Some(slot) if !(*slot).is_null() => Some(slot),
        _ => None,
    };

    let outcome = catch_unwind(|| {
        let aggr_ptr: *mut W = ffi::sqlite3_user_data(ctx).cast::<W>();
        assert!(
            !aggr_ptr.is_null(),
            "Internal error - null aggregate pointer"
        );
        let acc: Option<&mut A> = pac.map(|slot| &mut **slot);
        (*aggr_ptr).value(acc)
    });

    let result = match outcome {
        Ok(result) => result,
        Err(_) => {
            report_error(ctx, &Error::UnwindingPanic);
            return;
        }
    };
    sql_result(ctx, &[], result);
}
861
#[cfg(test)]
mod test {
    use regex::Regex;
    use std::os::raw::c_double;

    #[cfg(feature = "window")]
    use crate::functions::WindowAggregate;
    use crate::functions::{Aggregate, Context, FunctionFlags, SqlFnArg, SubType};
    use crate::{Connection, Error, Result};

    /// Scalar `half(x)`: returns its single argument, read as `c_double`, divided by two.
    ///
    /// Also checks that the context reports exactly one argument and that
    /// `Context::get_connection` succeeds from inside the callback.
    fn half(ctx: &Context<'_>) -> Result<c_double> {
        assert!(!ctx.is_empty());
        assert_eq!(ctx.len(), 1, "called with unexpected number of arguments");
        // NOTE(review): `get_connection` is `unsafe`; the borrowed connection is only
        // inspected for success here and is never used to run statements.
        assert!(unsafe {
            ctx.get_connection()
                .as_ref()
                .map(::std::ops::Deref::deref)
                .is_ok()
        });
        let value = ctx.get::<c_double>(0)?;
        Ok(value / 2f64)
    }

    /// Registers `half` as a deterministic scalar function and evaluates it once.
    #[test]
    fn test_function_half() -> Result<()> {
        let db = Connection::open_in_memory()?;
        db.create_scalar_function(
            "half",
            1,
            FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
            half,
        )?;
        let result: f64 = db.one_column("SELECT half(6)")?;

        assert!((3f64 - result).abs() < f64::EPSILON);
        Ok(())
    }

    /// After `remove_function`, calling the removed function must fail.
    #[test]
    fn test_remove_function() -> Result<()> {
        let db = Connection::open_in_memory()?;
        db.create_scalar_function(
            "half",
            1,
            FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
            half,
        )?;
        let result: f64 = db.one_column("SELECT half(6)")?;
        assert!((3f64 - result).abs() < f64::EPSILON);

        db.remove_function("half", 1)?;
        let result: Result<f64> = db.one_column("SELECT half(6)");
        result.unwrap_err();
        Ok(())
    }

    // This implementation of a regexp scalar function uses SQLite's auxiliary data
    // (https://www.sqlite.org/c3ref/get_auxdata.html) to avoid recompiling the regular
    // expression multiple times within one query.
    fn regexp_with_auxiliary(ctx: &Context<'_>) -> Result<bool> {
        assert_eq!(ctx.len(), 2, "called with unexpected number of arguments");
        type BoxError = Box<dyn std::error::Error + Send + Sync + 'static>;
        // Compiled pattern is cached as auxiliary data on argument 0.
        let regexp: std::sync::Arc<Regex> = ctx
            .get_or_create_aux(0, |vr| -> Result<_, BoxError> {
                Ok(Regex::new(vr.as_str()?)?)
            })?;

        let is_match = {
            let text = ctx
                .get_raw(1)
                .as_str()
                .map_err(|e| Error::UserFunctionError(e.into()))?;

            regexp.is_match(text)
        };

        Ok(is_match)
    }

    /// Uses `regexp_with_auxiliary` both as a one-off call and across several rows.
    #[test]
    fn test_function_regexp_with_auxiliary() -> Result<()> {
        let db = Connection::open_in_memory()?;
        db.execute_batch(
            "BEGIN;
             CREATE TABLE foo (x string);
             INSERT INTO foo VALUES ('lisa');
             INSERT INTO foo VALUES ('lXsi');
             INSERT INTO foo VALUES ('lisX');
             END;",
        )?;
        db.create_scalar_function(
            "regexp",
            2,
            FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
            regexp_with_auxiliary,
        )?;

        let result: bool = db.one_column("SELECT regexp('l.s[aeiouy]', 'lisa')")?;

        assert!(result);

        // COUNT(*) forces the function to run for every row of `foo`.
        let result: i64 =
            db.one_column("SELECT COUNT(*) FROM foo WHERE regexp('l.s[aeiouy]', x) == 1")?;

        assert_eq!(2, result);
        Ok(())
    }

    /// A variadic (`n_arg = -1`) function that concatenates all of its arguments.
    #[test]
    fn test_varargs_function() -> Result<()> {
        let db = Connection::open_in_memory()?;
        db.create_scalar_function(
            "my_concat",
            -1,
            FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
            |ctx| {
                let mut ret = String::new();

                for idx in 0..ctx.len() {
                    let s = ctx.get::<String>(idx)?;
                    ret.push_str(&s);
                }

                Ok(ret)
            },
        )?;

        for &(expected, query) in &[
            ("", "SELECT my_concat()"),
            ("onetwo", "SELECT my_concat('one', 'two')"),
            ("abc", "SELECT my_concat('a', 'b', 'c')"),
        ] {
            let result: String = db.one_column(query)?;
            assert_eq!(expected, result);
        }
        Ok(())
    }

    /// Auxiliary data stored as `i64` must be rejected when read back as `String`.
    /// It must still be readable as `i64`.
    ///
    /// The first row (`i = 0`) stores the aux value on the constant argument 0.
    /// The second row (`i = 1`) reads it back.
    #[test]
    fn test_get_aux_type_checking() -> Result<()> {
        let db = Connection::open_in_memory()?;
        db.create_scalar_function("example", 2, FunctionFlags::default(), |ctx| {
            if !ctx.get::<bool>(1)? {
                ctx.set_aux::<i64>(0, 100)?;
            } else {
                assert_eq!(ctx.get_aux::<String>(0), Err(Error::GetAuxWrongType));
                assert_eq!(*ctx.get_aux::<i64>(0)?.unwrap(), 100);
            }
            Ok(true)
        })?;

        // Step through *every* row. `one_column` stops after the first row, so the
        // `get_aux` branch above would never be evaluated.
        let mut stmt = db.prepare("SELECT example(0, i) FROM (SELECT 0 as i UNION SELECT 1)")?;
        // A failed assertion inside the function is reported as an SQL error, and
        // `collect` propagates it via `?`.
        let results: Vec<bool> = stmt
            .query_map([], |row| row.get(0))?
            .collect::<Result<Vec<bool>>>()?;
        assert_eq!(results, [true, true]);
        Ok(())
    }

    /// Aggregate summing its integer argument.
    ///
    /// `finalize` passes the accumulator through unchanged, so an empty input
    /// produces SQL NULL (checked in `test_sum`).
    struct Sum;
    /// Aggregate counting how many rows were stepped, ignoring argument values.
    struct Count;

    impl Aggregate<i64, Option<i64>> for Sum {
        fn init(&self, _: &mut Context<'_>) -> Result<i64> {
            Ok(0)
        }

        fn step(&self, ctx: &mut Context<'_>, sum: &mut i64) -> Result<()> {
            *sum += ctx.get::<i64>(0)?;
            Ok(())
        }

        fn finalize(&self, _: &mut Context<'_>, sum: Option<i64>) -> Result<Option<i64>> {
            Ok(sum)
        }
    }

    impl Aggregate<i64, i64> for Count {
        fn init(&self, _: &mut Context<'_>) -> Result<i64> {
            Ok(0)
        }

        fn step(&self, _ctx: &mut Context<'_>, sum: &mut i64) -> Result<()> {
            *sum += 1;
            Ok(())
        }

        fn finalize(&self, _: &mut Context<'_>, sum: Option<i64>) -> Result<i64> {
            // No accumulator means no rows were stepped: report a count of zero.
            Ok(sum.unwrap_or(0))
        }
    }

    #[test]
    fn test_sum() -> Result<()> {
        let db = Connection::open_in_memory()?;
        db.create_aggregate_function(
            "my_sum",
            1,
            FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
            Sum,
        )?;

        // sum should return NULL when given no rows (contrast with count below)
        let no_result = "SELECT my_sum(i) FROM (SELECT 2 AS i WHERE 1 <> 1)";
        let result: Option<i64> = db.one_column(no_result)?;
        assert!(result.is_none());

        let single_sum = "SELECT my_sum(i) FROM (SELECT 2 AS i UNION ALL SELECT 2)";
        let result: i64 = db.one_column(single_sum)?;
        assert_eq!(4, result);

        // Two independent aggregate instances in the same query.
        let dual_sum = "SELECT my_sum(i), my_sum(j) FROM (SELECT 2 AS i, 1 AS j UNION ALL SELECT \
                        2, 1)";
        let result: (i64, i64) = db.query_row(dual_sum, [], |r| Ok((r.get(0)?, r.get(1)?)))?;
        assert_eq!((4, 2), result);
        Ok(())
    }

    #[test]
    fn test_count() -> Result<()> {
        let db = Connection::open_in_memory()?;
        db.create_aggregate_function(
            "my_count",
            -1,
            FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
            Count,
        )?;

        // count should return 0 when given no rows (contrast with sum above)
        let no_result = "SELECT my_count(i) FROM (SELECT 2 AS i WHERE 1 <> 1)";
        let result: i64 = db.one_column(no_result)?;
        assert_eq!(result, 0);

        let single_count = "SELECT my_count(i) FROM (SELECT 2 AS i UNION ALL SELECT 2)";
        let result: i64 = db.one_column(single_count)?;
        assert_eq!(2, result);
        Ok(())
    }

    /// Window variant of `Sum`: `inverse` subtracts a value leaving the frame, and
    /// `value` reports the current running sum (or NULL when there is no accumulator).
    #[cfg(feature = "window")]
    impl WindowAggregate<i64, Option<i64>> for Sum {
        fn inverse(&self, ctx: &mut Context<'_>, sum: &mut i64) -> Result<()> {
            *sum -= ctx.get::<i64>(0)?;
            Ok(())
        }

        fn value(&self, sum: Option<&mut i64>) -> Result<Option<i64>> {
            Ok(sum.copied())
        }
    }

    /// Sliding-window sum over a `1 PRECEDING .. 1 FOLLOWING` frame.
    ///
    /// Each expected value is the sum of `y` for the previous, current and next rows.
    #[test]
    #[cfg(feature = "window")]
    fn test_window() -> Result<()> {
        use fallible_iterator::FallibleIterator;

        let db = Connection::open_in_memory()?;
        db.create_window_function(
            "sumint",
            1,
            FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_DETERMINISTIC,
            Sum,
        )?;
        db.execute_batch(
            "CREATE TABLE t3(x, y);
             INSERT INTO t3 VALUES('a', 4),
                     ('b', 5),
                     ('c', 3),
                     ('d', 8),
                     ('e', 1);",
        )?;

        let mut stmt = db.prepare(
            "SELECT x, sumint(y) OVER (
                   ORDER BY x ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING
                 ) AS sum_y
                 FROM t3 ORDER BY x;",
        )?;

        let results: Vec<(String, i64)> = stmt
            .query([])?
            .map(|row| Ok((row.get("x")?, row.get("sum_y")?)))
            .collect()?;
        let expected = vec![
            ("a".to_owned(), 9),
            ("b".to_owned(), 12),
            ("c".to_owned(), 16),
            ("d".to_owned(), 12),
            ("e".to_owned(), 9),
        ];
        assert_eq!(expected, results);
        Ok(())
    }

    /// Round-trips a value subtype: `test_setsubtype` attaches a subtype and
    /// `test_getsubtype` reads it back.
    #[test]
    fn test_sub_type() -> Result<()> {
        fn test_getsubtype(ctx: &Context<'_>) -> Result<i32> {
            Ok(ctx.get_subtype(0) as i32)
        }
        fn test_setsubtype(ctx: &Context<'_>) -> Result<(SqlFnArg, SubType)> {
            use std::os::raw::c_uint;
            let value = ctx.get_arg(0);
            let sub_type = ctx.get::<c_uint>(1)?;
            Ok((value, Some(sub_type)))
        }
        let db = Connection::open_in_memory()?;
        db.create_scalar_function(
            "test_getsubtype",
            1,
            FunctionFlags::SQLITE_UTF8,
            test_getsubtype,
        )?;
        // The setter declares SQLITE_RESULT_SUBTYPE because its result carries a subtype.
        db.create_scalar_function(
            "test_setsubtype",
            2,
            FunctionFlags::SQLITE_UTF8 | FunctionFlags::SQLITE_RESULT_SUBTYPE,
            test_setsubtype,
        )?;
        // A plain literal has no subtype.
        let result: i32 = db.one_column("SELECT test_getsubtype('hello');")?;
        assert_eq!(0, result);

        let result: i32 = db.one_column("SELECT test_getsubtype(test_setsubtype('hello',123));")?;
        assert_eq!(123, result);

        Ok(())
    }
}