1use std::ops::Deref;
4
5use crate::ffi;
6use crate::types::{ToSql, ToSqlOutput, ValueRef};
7use crate::{Connection, DatabaseName, Result, Row};
8
/// Incrementally built SQL text, used in this module to assemble `PRAGMA`
/// statements.
///
/// Keywords are validated, identifiers are quoted when needed, and string
/// literals are escaped by the `push_*` methods, so caller-supplied names and
/// values cannot break out of their token.
#[derive(Debug, Default, Clone)]
pub struct Sql {
    buf: String,
}
12
13impl Sql {
14 pub fn new() -> Self {
15 Self { buf: String::new() }
16 }
17
18 pub fn push_pragma(
19 &mut self,
20 schema_name: Option<DatabaseName<'_>>,
21 pragma_name: &str,
22 ) -> Result<()> {
23 self.push_keyword("PRAGMA")?;
24 self.push_space();
25 if let Some(schema_name) = schema_name {
26 self.push_schema_name(schema_name);
27 self.push_dot();
28 }
29 self.push_keyword(pragma_name)
30 }
31
32 pub fn push_keyword(&mut self, keyword: &str) -> Result<()> {
33 if !keyword.is_empty() && is_identifier(keyword) {
34 self.buf.push_str(keyword);
35 Ok(())
36 } else {
37 Err(err!(ffi::SQLITE_MISUSE, "Invalid keyword \"{keyword}\""))
38 }
39 }
40
41 pub fn push_schema_name(&mut self, schema_name: DatabaseName<'_>) {
42 match schema_name {
43 DatabaseName::Main => self.buf.push_str("main"),
44 DatabaseName::Temp => self.buf.push_str("temp"),
45 DatabaseName::Attached(s) => self.push_identifier(s),
46 DatabaseName::C(s) => self.push_identifier(s.to_str().expect("invalid database name")),
47 };
48 }
49
50 pub fn push_identifier(&mut self, s: &str) {
51 if is_identifier(s) {
52 self.buf.push_str(s);
53 } else {
54 self.wrap_and_escape(s, '"');
55 }
56 }
57
58 pub fn push_value(&mut self, value: &dyn ToSql) -> Result<()> {
59 let value = value.to_sql()?;
60 let value = match value {
61 ToSqlOutput::Borrowed(v) => v,
62 ToSqlOutput::Owned(ref v) => ValueRef::from(v),
63 #[cfg(feature = "blob")]
64 ToSqlOutput::ZeroBlob(_) => {
65 return Err(err!(ffi::SQLITE_MISUSE, "Unsupported value \"{value:?}\""));
66 }
67 #[cfg(feature = "functions")]
68 ToSqlOutput::Arg(_) => {
69 return Err(err!(ffi::SQLITE_MISUSE, "Unsupported value \"{value:?}\""));
70 }
71 #[cfg(feature = "array")]
72 ToSqlOutput::Array(_) => {
73 return Err(err!(ffi::SQLITE_MISUSE, "Unsupported value \"{value:?}\""));
74 }
75 };
76 match value {
77 ValueRef::Integer(i) => {
78 self.push_int(i);
79 }
80 ValueRef::Real(r) => {
81 self.push_real(r);
82 }
83 ValueRef::Text(s) => {
84 let s = std::str::from_utf8(s)?;
85 self.push_string_literal(s);
86 }
87 _ => {
88 return Err(err!(ffi::SQLITE_MISUSE, "Unsupported value \"{value:?}\""));
89 }
90 };
91 Ok(())
92 }
93
94 pub fn push_string_literal(&mut self, s: &str) {
95 self.wrap_and_escape(s, '\'');
96 }
97
98 pub fn push_int(&mut self, i: i64) {
99 self.buf.push_str(&i.to_string());
100 }
101
102 pub fn push_real(&mut self, f: f64) {
103 self.buf.push_str(&f.to_string());
104 }
105
106 pub fn push_space(&mut self) {
107 self.buf.push(' ');
108 }
109
110 pub fn push_dot(&mut self) {
111 self.buf.push('.');
112 }
113
114 pub fn push_equal_sign(&mut self) {
115 self.buf.push('=');
116 }
117
118 pub fn open_brace(&mut self) {
119 self.buf.push('(');
120 }
121
122 pub fn close_brace(&mut self) {
123 self.buf.push(')');
124 }
125
126 pub fn as_str(&self) -> &str {
127 &self.buf
128 }
129
130 fn wrap_and_escape(&mut self, s: &str, quote: char) {
131 self.buf.push(quote);
132 let chars = s.chars();
133 for ch in chars {
134 if ch == quote {
136 self.buf.push(ch);
137 }
138 self.buf.push(ch);
139 }
140 self.buf.push(quote);
141 }
142}
143
144impl Deref for Sql {
145 type Target = str;
146
147 fn deref(&self) -> &str {
148 self.as_str()
149 }
150}
151
152impl Connection {
153 pub fn pragma_query_value<T, F>(
161 &self,
162 schema_name: Option<DatabaseName<'_>>,
163 pragma_name: &str,
164 f: F,
165 ) -> Result<T>
166 where
167 F: FnOnce(&Row<'_>) -> Result<T>,
168 {
169 let mut query = Sql::new();
170 query.push_pragma(schema_name, pragma_name)?;
171 self.query_row(&query, [], f)
172 }
173
174 pub fn pragma_query<F>(
179 &self,
180 schema_name: Option<DatabaseName<'_>>,
181 pragma_name: &str,
182 mut f: F,
183 ) -> Result<()>
184 where
185 F: FnMut(&Row<'_>) -> Result<()>,
186 {
187 let mut query = Sql::new();
188 query.push_pragma(schema_name, pragma_name)?;
189 let mut stmt = self.prepare(&query)?;
190 let mut rows = stmt.query([])?;
191 while let Some(result_row) = rows.next()? {
192 let row = result_row;
193 f(row)?;
194 }
195 Ok(())
196 }
197
198 pub fn pragma<F, V>(
208 &self,
209 schema_name: Option<DatabaseName<'_>>,
210 pragma_name: &str,
211 pragma_value: V,
212 mut f: F,
213 ) -> Result<()>
214 where
215 F: FnMut(&Row<'_>) -> Result<()>,
216 V: ToSql,
217 {
218 let mut sql = Sql::new();
219 sql.push_pragma(schema_name, pragma_name)?;
220 sql.open_brace();
224 sql.push_value(&pragma_value)?;
225 sql.close_brace();
226 let mut stmt = self.prepare(&sql)?;
227 let mut rows = stmt.query([])?;
228 while let Some(result_row) = rows.next()? {
229 let row = result_row;
230 f(row)?;
231 }
232 Ok(())
233 }
234
235 pub fn pragma_update<V>(
240 &self,
241 schema_name: Option<DatabaseName<'_>>,
242 pragma_name: &str,
243 pragma_value: V,
244 ) -> Result<()>
245 where
246 V: ToSql,
247 {
248 let mut sql = Sql::new();
249 sql.push_pragma(schema_name, pragma_name)?;
250 sql.push_equal_sign();
254 sql.push_value(&pragma_value)?;
255 self.execute_batch(&sql)
256 }
257
258 pub fn pragma_update_and_check<F, T, V>(
262 &self,
263 schema_name: Option<DatabaseName<'_>>,
264 pragma_name: &str,
265 pragma_value: V,
266 f: F,
267 ) -> Result<T>
268 where
269 F: FnOnce(&Row<'_>) -> Result<T>,
270 V: ToSql,
271 {
272 let mut sql = Sql::new();
273 sql.push_pragma(schema_name, pragma_name)?;
274 sql.push_equal_sign();
278 sql.push_value(&pragma_value)?;
279 self.query_row(&sql, [], f)
280 }
281}
282
/// Returns `true` if `s` can be written as a bare (unquoted) SQL identifier.
///
/// The first character must satisfy [`is_identifier_start`] and every later
/// character [`is_identifier_continue`]. The empty string is rejected: it has
/// no unquoted spelling, so `Sql::push_identifier` must quote it (`""`)
/// rather than emit nothing.
///
/// SQL keywords are not detected; a name such as `select` is accepted.
fn is_identifier(s: &str) -> bool {
    let mut chars = s.chars();
    match chars.next() {
        Some(first) => is_identifier_start(first) && chars.all(is_identifier_continue),
        None => false,
    }
}

/// Characters allowed first in a bare identifier: ASCII letters, `_`, or any
/// non-ASCII character.
fn is_identifier_start(c: char) -> bool {
    c == '_' || c.is_ascii_alphabetic() || !c.is_ascii()
}

/// Characters allowed after the first: anything [`is_identifier_start`]
/// accepts, plus ASCII digits and `$`.
fn is_identifier_continue(c: char) -> bool {
    is_identifier_start(c) || c == '$' || c.is_ascii_digit()
}
309
#[cfg(test)]
mod test {
    use super::Sql;
    use crate::pragma;
    use crate::{Connection, DatabaseName, Result};

    /// A fresh in-memory database reports `user_version` 0.
    #[test]
    fn pragma_query_value() -> Result<()> {
        let conn = Connection::open_in_memory()?;
        let version: i32 = conn.pragma_query_value(None, "user_version", |r| r.get(0))?;
        assert_eq!(0, version);
        Ok(())
    }

    /// Same value, read through the table-valued `pragma_user_version` function.
    #[test]
    #[cfg(feature = "modern_sqlite")]
    fn pragma_func_query_value() -> Result<()> {
        let conn = Connection::open_in_memory()?;
        let version: i32 = conn.one_column("SELECT user_version FROM pragma_user_version")?;
        assert_eq!(0, version);
        Ok(())
    }

    #[test]
    fn pragma_query_no_schema() -> Result<()> {
        let conn = Connection::open_in_memory()?;
        // Sentinel so the assertion proves the callback actually ran.
        let mut version = -1;
        conn.pragma_query(None, "user_version", |r| {
            version = r.get(0)?;
            Ok(())
        })?;
        assert_eq!(0, version);
        Ok(())
    }

    #[test]
    fn pragma_query_with_schema() -> Result<()> {
        let conn = Connection::open_in_memory()?;
        // Sentinel so the assertion proves the callback actually ran.
        let mut version = -1;
        conn.pragma_query(Some(DatabaseName::Main), "user_version", |r| {
            version = r.get(0)?;
            Ok(())
        })?;
        assert_eq!(0, version);
        Ok(())
    }

    /// `table_info('sqlite_master')` yields one row per column (5 of them).
    #[test]
    fn pragma() -> Result<()> {
        let conn = Connection::open_in_memory()?;
        let mut names = Vec::new();
        conn.pragma(None, "table_info", "sqlite_master", |r| {
            let name: String = r.get(1)?;
            names.push(name);
            Ok(())
        })?;
        assert_eq!(5, names.len());
        Ok(())
    }

    /// Same query via the table-valued `pragma_table_info` function.
    #[test]
    #[cfg(feature = "modern_sqlite")]
    fn pragma_func() -> Result<()> {
        let conn = Connection::open_in_memory()?;
        let mut stmt = conn.prepare("SELECT * FROM pragma_table_info(?1)")?;
        let mut names = Vec::new();
        let mut rows = stmt.query(["sqlite_master"])?;
        while let Some(r) = rows.next()? {
            let name: String = r.get(1)?;
            names.push(name);
        }
        assert_eq!(5, names.len());
        Ok(())
    }

    #[test]
    fn pragma_update() -> Result<()> {
        let conn = Connection::open_in_memory()?;
        conn.pragma_update(None, "user_version", 1)
    }

    /// Sets `journal_mode` three ways (typed row getter, turbofish getter,
    /// `&dyn ToSql` value) and checks the mode reported back each time.
    #[test]
    fn pragma_update_and_check() -> Result<()> {
        // Either answer is accepted for an in-memory database.
        fn accepted(mode: &str) -> bool {
            mode == "off" || mode == "memory"
        }

        let conn = Connection::open_in_memory()?;

        let mode: String =
            conn.pragma_update_and_check(None, "journal_mode", "OFF", |r| r.get(0))?;
        assert!(accepted(&mode), "mode: {mode:?}");

        let mode = conn
            .pragma_update_and_check(None, "journal_mode", "OFF", |r| r.get::<_, String>(0))?;
        assert!(accepted(&mode), "mode: {mode:?}");

        let param: &dyn crate::ToSql = &"OFF";
        let mode = conn
            .pragma_update_and_check(None, "journal_mode", param, |r| r.get::<_, String>(0))?;
        assert!(accepted(&mode), "mode: {mode:?}");
        Ok(())
    }

    #[test]
    fn is_identifier() {
        assert!(pragma::is_identifier("full"));
        assert!(pragma::is_identifier("r2d2"));
        assert!(!pragma::is_identifier("sp ce"));
        assert!(!pragma::is_identifier("semi;colon"));
    }

    /// An attached schema name containing `"` is quoted and the quote doubled.
    #[test]
    fn double_quote() {
        let mut buf = Sql::new();
        buf.push_schema_name(DatabaseName::Attached(r#"schema";--"#));
        assert_eq!(r#""schema"";--""#, buf.as_str());
    }

    /// A string literal containing `'` is quoted and the quote doubled.
    #[test]
    fn wrap_and_escape() {
        let mut buf = Sql::new();
        buf.push_string_literal("value'; --");
        assert_eq!("'value''; --'", buf.as_str());
    }

    /// `PRAGMA locking_mode = exclusive` reports the new mode as a row. This
    /// test expects `pragma_update` to fail only when `extra_check` is enabled.
    #[test]
    fn locking_mode() -> Result<()> {
        let conn = Connection::open_in_memory()?;
        let outcome = conn.pragma_update(None, "locking_mode", "exclusive");
        if !cfg!(feature = "extra_check") {
            return outcome;
        }
        outcome.unwrap_err();
        Ok(())
    }
}