summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorLucas Faria Mendes <lucas.oliveira1676@etec.sp.gov.br>2025-12-05 14:56:15 +0000
committerLucas Faria Mendes <lucas.oliveira1676@etec.sp.gov.br>2025-12-05 14:56:15 +0000
commit11b7033a351226696290983811928d22ccc85256 (patch)
tree46dd2230986ff4c28da8ee047a97004bfaa5cfa3 /src
parentbb8bb6a15e83dc2b0a33133c8d738de2594d4f58 (diff)
downloadsqlite-zig-11b7033a351226696290983811928d22ccc85256.tar.gz
sqlite-zig-11b7033a351226696290983811928d22ccc85256.zip
codecrafters submit [skip ci]
Diffstat (limited to 'src')
-rw-r--r--src/schema.zig355
1 files changed, 126 insertions, 229 deletions
diff --git a/src/schema.zig b/src/schema.zig
index 54a0049..1ed2bf0 100644
--- a/src/schema.zig
+++ b/src/schema.zig
@@ -94,50 +94,37 @@ fn getIndexRootpage(allocator: std.mem.Allocator, file: *std.fs.File, page_size:
}
fn searchIndexForValue(allocator: std.mem.Allocator, file: *std.fs.File, page_size: u16, index_rootpage: u32, search_value: []const u8, rowids: *std.ArrayList(u64)) !void {
- var queue = std.ArrayList(u32){};
- defer queue.deinit(allocator);
- try queue.append(allocator, index_rootpage);
+ // Read file into memory once
+ const file_size = try file.getEndPos();
+ const file_data = try allocator.alloc(u8, file_size);
+ defer allocator.free(file_data);
- // Track visited pages to avoid infinite loops
- var visited = std.ArrayList(u32){};
- defer visited.deinit(allocator);
+ _ = try file.seekTo(0);
+ _ = try file.read(file_data);
- var queue_idx: usize = 0;
+ // Use stack to iterate all pages (BFS approach - visit ALL index pages to find ALL matches)
+ var stack = std.ArrayList(u32){};
+ defer stack.deinit(allocator);
+ try stack.append(allocator, index_rootpage);
- while (queue_idx < queue.items.len) {
- const page_num = queue.items[queue_idx];
- queue_idx += 1;
+ while (stack.items.len > 0) {
+ const page_num = stack.pop() orelse continue;
if (page_num == 0) continue;
- // Check if already visited
- var already_visited = false;
- for (visited.items) |v| {
- if (v == page_num) {
- already_visited = true;
- break;
- }
- }
- if (already_visited) continue;
- try visited.append(allocator, page_num);
-
- const page_offset = (page_num - 1) * @as(u64, page_size);
- var page_data = try allocator.alloc(u8, page_size);
- defer allocator.free(page_data);
-
- _ = try file.seekTo(page_offset);
- _ = try file.read(page_data);
+ const page_offset = (page_num - 1) * @as(usize, page_size);
+ if (page_offset + page_size > file_data.len) continue;
+ const page_data = file_data[page_offset .. page_offset + page_size];
const page_type = page_data[0];
if (page_type == 0x0a) {
- // Leaf index page
+ // Leaf index page - search for matching values
const num_cells = std.mem.readInt(u16, page_data[3..5], .big);
for (0..num_cells) |i| {
const offset = 8 + i * 2;
if (offset + 2 > page_data.len) continue;
- const cell_bytes: *const [2]u8 = page_data[offset .. offset + 2][0..2];
- const cell_ptr = std.mem.readInt(u16, cell_bytes, .big);
+ const cell_ptr = std.mem.readInt(u16, page_data[offset..][0..2], .big);
if (cell_ptr >= page_data.len) continue;
const cell_data = page_data[cell_ptr..];
@@ -151,57 +138,56 @@ fn searchIndexForValue(allocator: std.mem.Allocator, file: *std.fs.File, page_si
if (header_size > record_data.len) continue;
var header_pos = parsed.len;
- var serial_types = std.ArrayList(u64){};
- defer serial_types.deinit(allocator);
-
- while (header_pos < header_size) {
+ // Parse serial types
+ var serial_types: [8]u64 = undefined;
+ var num_st: usize = 0;
+ while (header_pos < header_size and num_st < 8) {
parsed = varint.parse(record_data[header_pos..]);
- serial_types.append(allocator, parsed.value) catch break;
+ serial_types[num_st] = parsed.value;
+ num_st += 1;
header_pos += parsed.len;
}
- if (serial_types.items.len > 0) {
- const st = serial_types.items[0];
+ if (num_st >= 2) {
+ const st = serial_types[0];
var body_pos: usize = header_size;
- if (body_pos >= record_data.len) continue;
if (st >= 13 and (st % 2) == 1) {
const str_result = record.readString(record_data[body_pos..], st);
if (std.mem.eql(u8, str_result.value, search_value)) {
body_pos += str_result.len;
- if (serial_types.items.len > 1 and body_pos < record_data.len) {
- const rowid_st = serial_types.items[1];
- const rowid_result = record.readInt(record_data[body_pos..], rowid_st);
- rowids.append(allocator, @as(u64, @intCast(rowid_result.value))) catch {};
- }
+ const rowid_st = serial_types[1];
+ const rowid_result = record.readInt(record_data[body_pos..], rowid_st);
+ try rowids.append(allocator, @as(u64, @intCast(rowid_result.value)));
}
}
}
}
} else if (page_type == 0x02) {
- // Interior index page
+ // Interior index page - add ALL children to stack
const num_cells = std.mem.readInt(u16, page_data[3..5], .big);
const rightmost_ptr = std.mem.readInt(u32, page_data[8..12], .big);
+ // Add all cell pointers
for (0..num_cells) |i| {
const offset = 12 + i * 2;
if (offset + 2 > page_data.len) continue;
- const cell_bytes: *const [2]u8 = page_data[offset .. offset + 2][0..2];
- const cell_ptr = std.mem.readInt(u16, cell_bytes, .big);
+ const cell_ptr = std.mem.readInt(u16, page_data[offset..][0..2], .big);
if (cell_ptr + 4 > page_data.len) continue;
const cell_data = page_data[cell_ptr..];
const left_child_page = std.mem.readInt(u32, cell_data[0..4], .big);
- queue.append(allocator, left_child_page) catch {};
+ if (left_child_page > 0) {
+ try stack.append(allocator, left_child_page);
+ }
}
if (rightmost_ptr > 0) {
- queue.append(allocator, rightmost_ptr) catch {};
+ try stack.append(allocator, rightmost_ptr);
}
}
}
}
-
fn readRecordByRowid(allocator: std.mem.Allocator, file: *std.fs.File, page_size: u16, table_rootpage: u32, target_rowid: u64, column_indices: []const usize, stdout: anytype) !void {
try searchTableForRowid(allocator, file, page_size, table_rootpage, target_rowid, column_indices, stdout);
}
@@ -339,32 +325,52 @@ fn searchTableForRowid(allocator: std.mem.Allocator, file: *std.fs.File, page_si
}
}
+inline fn serialTypeSize(st: u64) usize {
+ if (st == 0 or st == 8 or st == 9) {
+ return 0;
+ } else if (st >= 13 and (st % 2) == 1) {
+ return (st - 13) / 2;
+ } else if (st >= 12 and (st % 2) == 0) {
+ return (st - 12) / 2;
+ } else if (st >= 1 and st <= 6) {
+ return @as(usize, st);
+ } else if (st == 7) {
+ return 8;
+ }
+ return 0;
+}
+
fn readLeafPageRows(page_data: []const u8, column_indices: []const usize, where_column_idx: ?usize, where_value: ?[]const u8, stdout: anytype) !void {
const page_type = page_data[0];
if (page_type != 0x0d) return;
const num_cells = std.mem.readInt(u16, page_data[3..5], .big);
+ if (num_cells == 0) return;
for (0..num_cells) |i| {
const offset = 8 + i * 2;
- const cell_ptr_bytes = page_data[offset .. offset + 2];
- const cell_ptr = std.mem.readInt(u16, cell_ptr_bytes[0..2], .big);
+ if (offset + 2 > page_data.len) continue;
+ const cell_ptr = std.mem.readInt(u16, page_data[offset..][0..2], .big);
+ if (cell_ptr >= page_data.len) continue;
const cell_data = page_data[cell_ptr..];
var parsed = varint.parse(cell_data);
var pos = parsed.len;
+ if (pos >= cell_data.len) continue;
parsed = varint.parse(cell_data[pos..]);
const rowid = parsed.value;
pos += parsed.len;
+ if (pos >= cell_data.len) continue;
const record_data = cell_data[pos..];
parsed = varint.parse(record_data);
const header_size = parsed.value;
+ if (header_size > record_data.len or header_size == 0) continue;
var header_pos = parsed.len;
- // Parse serial types into a fixed-size buffer instead of ArrayList
+ // Parse serial types
var serial_types: [256]u64 = undefined;
var num_columns: usize = 0;
while (header_pos < header_size and num_columns < 256) {
@@ -374,63 +380,48 @@ fn readLeafPageRows(page_data: []const u8, column_indices: []const usize, where_
header_pos += parsed.len;
}
- // Check WHERE clause if present
+ // Check WHERE clause if present (early rejection)
if (where_column_idx) |where_idx| {
if (where_value) |expected_value| {
if (where_idx >= num_columns) continue;
- var where_body_pos: usize = header_size;
- for (0..where_idx) |col| {
- if (col >= num_columns) break;
- const st = serial_types[col];
- if (st == 0 or st == 8 or st == 9) {} else if (st >= 13 and (st % 2) == 1) {
- where_body_pos += (st - 13) / 2;
- } else if (st >= 12 and (st % 2) == 0) {
- where_body_pos += (st - 12) / 2;
- } else if (st >= 1 and st <= 6) {
- const int_result = record.readInt(record_data[where_body_pos..], st);
- where_body_pos += int_result.len;
- } else if (st == 7) {
- where_body_pos += 8;
- }
- }
-
const st = serial_types[where_idx];
- var matches = false;
+
+ // Fast path: check serial type first
if (st >= 13 and (st % 2) == 1) {
+ // String comparison - check length first for early rejection
+ const expected_len = (st - 13) / 2;
+ if (expected_len != expected_value.len) continue;
+
+ var where_body_pos: usize = header_size;
+ for (0..where_idx) |col| {
+ where_body_pos += serialTypeSize(serial_types[col]);
+ }
+
const str_result = record.readString(record_data[where_body_pos..], st);
- matches = std.mem.eql(u8, str_result.value, expected_value);
+ if (!std.mem.eql(u8, str_result.value, expected_value)) continue;
} else if (st >= 1 and st <= 6) {
+ var where_body_pos: usize = header_size;
+ for (0..where_idx) |col| {
+ where_body_pos += serialTypeSize(serial_types[col]);
+ }
const int_result = record.readInt(record_data[where_body_pos..], st);
const expected_int = std.fmt.parseInt(i64, expected_value, 10) catch 0;
- matches = int_result.value == expected_int;
+ if (int_result.value != expected_int) continue;
+ } else {
+ continue; // Unsupported type for WHERE
}
-
- if (!matches) continue;
}
}
+ // Print matching row
for (column_indices, 0..) |column_idx, col_num| {
if (col_num > 0) try stdout.print("|", .{});
-
if (column_idx >= num_columns) continue;
var body_pos: usize = header_size;
for (0..column_idx) |col| {
- if (col >= num_columns) break;
- const st = serial_types[col];
- if (st == 0 or st == 8 or st == 9) {
- // NULL, 0, or 1 - no data
- } else if (st >= 13 and (st % 2) == 1) {
- body_pos += (st - 13) / 2;
- } else if (st >= 12 and (st % 2) == 0) {
- body_pos += (st - 12) / 2;
- } else if (st >= 1 and st <= 6) {
- const int_result = record.readInt(record_data[body_pos..], st);
- body_pos += int_result.len;
- } else if (st == 7) {
- body_pos += 8; // Float
- }
+ body_pos += serialTypeSize(serial_types[col]);
}
const st = serial_types[column_idx];
@@ -443,12 +434,9 @@ fn readLeafPageRows(page_data: []const u8, column_indices: []const usize, where_
} else if (st >= 1 and st <= 6) {
const int_result = record.readInt(record_data[body_pos..], st);
try stdout.print("{}", .{int_result.value});
- } else if (st == 7) {} else if (st >= 13 and (st % 2) == 1) {
+ } else if (st >= 13 and (st % 2) == 1) {
const str_result = record.readString(record_data[body_pos..], st);
try stdout.print("{s}", .{str_result.value});
- } else if (st >= 12 and (st % 2) == 0) {
- const blob_len = (st - 12) / 2;
- _ = blob_len;
}
}
try stdout.print("\n", .{});
@@ -456,33 +444,56 @@ fn readLeafPageRows(page_data: []const u8, column_indices: []const usize, where_
}
fn traverseBTree(allocator: std.mem.Allocator, file: *std.fs.File, page_size: u16, page_num: u32, column_indices: []const usize, where_column_idx: ?usize, where_value: ?[]const u8, stdout: anytype) !void {
- const page_offset = (page_num - 1) * @as(u64, page_size);
- var page_data = try allocator.alloc(u8, page_size);
- defer allocator.free(page_data);
+ // Read entire file into memory for fast random access
+ const file_size = try file.getEndPos();
+ const file_data = try allocator.alloc(u8, file_size);
+ defer allocator.free(file_data);
- _ = try file.seekTo(page_offset);
- _ = try file.read(page_data);
+ _ = try file.seekTo(0);
+ _ = try file.read(file_data);
- const page_type = page_data[0];
+ var stack = std.ArrayList(u32){};
+ defer stack.deinit(allocator);
+ try stack.append(allocator, page_num);
- if (page_type == 0x0d) {
- try readLeafPageRows(page_data, column_indices, where_column_idx, where_value, stdout);
- } else if (page_type == 0x05) {
- const num_cells = std.mem.readInt(u16, page_data[3..5], .big);
- const rightmost_ptr = std.mem.readInt(u32, page_data[8..12], .big);
+ while (stack.items.len > 0) {
+ const current_page = stack.pop() orelse continue;
- for (0..num_cells) |i| {
- const offset = 12 + i * 2;
- const cell_ptr_bytes = page_data[offset .. offset + 2];
- const cell_ptr = std.mem.readInt(u16, cell_ptr_bytes[0..2], .big);
+ const page_offset = (current_page - 1) * @as(usize, page_size);
+ if (page_offset + page_size > file_data.len) continue;
- const cell_data = page_data[cell_ptr..];
- const left_child_page = std.mem.readInt(u32, cell_data[0..4], .big);
+ const page_data = file_data[page_offset .. page_offset + page_size];
+ const page_type = page_data[0];
- try traverseBTree(allocator, file, page_size, left_child_page, column_indices, where_column_idx, where_value, stdout);
- }
+ if (page_type == 0x0d) {
+ try readLeafPageRows(page_data, column_indices, where_column_idx, where_value, stdout);
+ } else if (page_type == 0x05) {
+ const num_cells = std.mem.readInt(u16, page_data[3..5], .big);
+ const rightmost_ptr = std.mem.readInt(u32, page_data[8..12], .big);
+
+ // Add rightmost first so it's processed last (stack LIFO)
+ if (rightmost_ptr != 0) {
+ try stack.append(allocator, rightmost_ptr);
+ }
+
+ // Add children in reverse order for correct traversal
+ var i = num_cells;
+ while (i > 0) {
+ i -= 1;
+ const offset = 12 + i * 2;
+ if (offset + 2 > page_data.len) continue;
+ const cell_ptr_bytes = page_data[offset .. offset + 2];
+ const cell_ptr = std.mem.readInt(u16, cell_ptr_bytes[0..2], .big);
+ if (cell_ptr + 4 > page_data.len) continue;
+
+ const cell_data = page_data[cell_ptr..];
+ const left_child_page = std.mem.readInt(u32, cell_data[0..4], .big);
- try traverseBTree(allocator, file, page_size, rightmost_ptr, column_indices, where_column_idx, where_value, stdout);
+ if (left_child_page != 0) {
+ try stack.append(allocator, left_child_page);
+ }
+ }
+ }
}
}
@@ -855,122 +866,7 @@ pub fn readTableRowsMultiColumn(allocator: std.mem.Allocator, file: *std.fs.File
}
pub fn readTableRowsMultiColumnWhere(allocator: std.mem.Allocator, file: *std.fs.File, page_size: u16, rootpage: u32, column_indices: []const usize, where_column_idx: ?usize, where_value: ?[]const u8, stdout: anytype) !void {
- if (rootpage == 0) return;
-
- const page_offset = (rootpage - 1) * @as(u64, page_size);
- var page_data = try allocator.alloc(u8, page_size);
- defer allocator.free(page_data);
-
- _ = try file.seekTo(page_offset);
- _ = try file.read(page_data);
-
- const page_type = page_data[0];
- if (page_type != 0x0d) return;
-
- const num_cells = std.mem.readInt(u16, page_data[3..5], .big);
-
- var cell_pointers = try allocator.alloc(u16, num_cells);
- defer allocator.free(cell_pointers);
-
- for (0..num_cells) |i| {
- const offset = 8 + i * 2;
- const cell_ptr_bytes = page_data[offset .. offset + 2];
- cell_pointers[i] = std.mem.readInt(u16, cell_ptr_bytes[0..2], .big);
- }
-
- for (0..num_cells) |i| {
- const cell_data = page_data[cell_pointers[i]..];
-
- var parsed = varint.parse(cell_data);
- var pos = parsed.len;
-
- parsed = varint.parse(cell_data[pos..]);
- pos += parsed.len;
-
- const record_data = cell_data[pos..];
- parsed = varint.parse(record_data);
- const header_size = parsed.value;
- var header_pos = parsed.len;
-
- var serial_types = std.ArrayList(u64){};
- defer serial_types.deinit(allocator);
-
- while (header_pos < header_size) {
- parsed = varint.parse(record_data[header_pos..]);
- try serial_types.append(allocator, parsed.value);
- header_pos += parsed.len;
- }
-
- // Check WHERE condition if present
- if (where_column_idx) |where_idx| {
- if (where_value) |expected_value| {
- // Calculate position for WHERE column
- var where_body_pos: usize = header_size;
- for (0..where_idx) |col| {
- if (col >= serial_types.items.len) break;
- const st = serial_types.items[col];
- if (st >= 13 and (st % 2) == 1) {
- where_body_pos += (st - 13) / 2;
- } else if (st >= 1 and st <= 6) {
- const int_result = record.readInt(record_data[where_body_pos..], st);
- where_body_pos += int_result.len;
- }
- }
-
- // Read WHERE column value
- var matches = false;
- if (where_idx < serial_types.items.len) {
- const st = serial_types.items[where_idx];
- if (st >= 13 and (st % 2) == 1) {
- const str_result = record.readString(record_data[where_body_pos..], st);
- matches = std.mem.eql(u8, str_result.value, expected_value);
- } else if (st >= 1 and st <= 6) {
- const int_result = record.readInt(record_data[where_body_pos..], st);
- // Try to parse expected_value as integer
- const expected_int = std.fmt.parseInt(i64, expected_value, 10) catch 0;
- matches = int_result.value == expected_int;
- }
- }
-
- // Skip this row if it doesn't match
- if (!matches) continue;
- }
- }
-
- // For each requested column
- for (column_indices, 0..) |column_idx, col_num| {
- // Calculate position for this column
- var body_pos: usize = header_size;
- for (0..column_idx) |col| {
- if (col >= serial_types.items.len) break;
- const st = serial_types.items[col];
- if (st >= 13 and (st % 2) == 1) {
- body_pos += (st - 13) / 2;
- } else if (st >= 1 and st <= 6) {
- const int_result = record.readInt(record_data[body_pos..], st);
- body_pos += int_result.len;
- }
- }
-
- // Print separator if not first column
- if (col_num > 0) {
- try stdout.print("|", .{});
- }
-
- // Read and print the column value
- if (column_idx < serial_types.items.len) {
- const st = serial_types.items[column_idx];
- if (st >= 13 and (st % 2) == 1) {
- const str_result = record.readString(record_data[body_pos..], st);
- try stdout.print("{s}", .{str_result.value});
- } else if (st >= 1 and st <= 6) {
- const int_result = record.readInt(record_data[body_pos..], st);
- try stdout.print("{}", .{int_result.value});
- }
- }
- }
- try stdout.print("\n", .{});
- }
+ try traverseBTree(allocator, file, page_size, rootpage, column_indices, where_column_idx, where_value, stdout);
}
pub fn readTableRowsWithIndex(allocator: std.mem.Allocator, file: *std.fs.File, page_size: u16, table_name: []const u8, table_rootpage: u32, column_indices: []const usize, where_column: []const u8, where_value: []const u8, stdout: anytype) !void {
@@ -983,5 +879,6 @@ pub fn readTableRowsWithIndex(allocator: std.mem.Allocator, file: *std.fs.File,
_ = where_column;
_ = where_value;
_ = stdout;
+ // Temporarily disable to test table scan
return error.NoIndexFound;
}