diff options
| -rw-r--r-- | build.zig | 3 | ||||
| -rw-r--r-- | src/schema.zig | 355 |
2 files changed, 129 insertions, 229 deletions
@@ -2,11 +2,14 @@ const std = @import("std"); // Learn more about this file here: https://ziglang.org/learn/build-system pub fn build(b: *std.Build) void { + const optimize = b.standardOptimizeOption(.{ .preferred_optimize_mode = .ReleaseFast }); + const exe = b.addExecutable(.{ .name = "main", .root_module = b.createModule(.{ .root_source_file = b.path("src/main.zig"), .target = b.graph.host, + .optimize = optimize, }), }); diff --git a/src/schema.zig b/src/schema.zig index 54a0049..1ed2bf0 100644 --- a/src/schema.zig +++ b/src/schema.zig @@ -94,50 +94,37 @@ fn getIndexRootpage(allocator: std.mem.Allocator, file: *std.fs.File, page_size: } fn searchIndexForValue(allocator: std.mem.Allocator, file: *std.fs.File, page_size: u16, index_rootpage: u32, search_value: []const u8, rowids: *std.ArrayList(u64)) !void { - var queue = std.ArrayList(u32){}; - defer queue.deinit(allocator); - try queue.append(allocator, index_rootpage); + // Read file into memory once + const file_size = try file.getEndPos(); + const file_data = try allocator.alloc(u8, file_size); + defer allocator.free(file_data); - // Track visited pages to avoid infinite loops - var visited = std.ArrayList(u32){}; - defer visited.deinit(allocator); + _ = try file.seekTo(0); + _ = try file.read(file_data); - var queue_idx: usize = 0; + // Use stack to iterate all pages (BFS approach - visit ALL index pages to find ALL matches) + var stack = std.ArrayList(u32){}; + defer stack.deinit(allocator); + try stack.append(allocator, index_rootpage); - while (queue_idx < queue.items.len) { - const page_num = queue.items[queue_idx]; - queue_idx += 1; + while (stack.items.len > 0) { + const page_num = stack.pop() orelse continue; if (page_num == 0) continue; - // Check if already visited - var already_visited = false; - for (visited.items) |v| { - if (v == page_num) { - already_visited = true; - break; - } - } - if (already_visited) continue; - try visited.append(allocator, page_num); - - const page_offset = (page_num - 1) * @as(u64, page_size); - var page_data = try allocator.alloc(u8, page_size); - defer allocator.free(page_data); - - _ = try file.seekTo(page_offset); - _ = try file.read(page_data); + const page_offset = (page_num - 1) * @as(usize, page_size); + if (page_offset + page_size > file_data.len) continue; + const page_data = file_data[page_offset .. page_offset + page_size]; const page_type = page_data[0]; if (page_type == 0x0a) { - // Leaf index page + // Leaf index page - search for matching values const num_cells = std.mem.readInt(u16, page_data[3..5], .big); for (0..num_cells) |i| { const offset = 8 + i * 2; if (offset + 2 > page_data.len) continue; - const cell_bytes: *const [2]u8 = page_data[offset .. offset + 2][0..2]; - const cell_ptr = std.mem.readInt(u16, cell_bytes, .big); + const cell_ptr = std.mem.readInt(u16, page_data[offset..][0..2], .big); if (cell_ptr >= page_data.len) continue; const cell_data = page_data[cell_ptr..]; @@ -151,57 +138,56 @@ fn searchIndexForValue(allocator: std.mem.Allocator, file: *std.fs.File, page_si if (header_size > record_data.len) continue; var header_pos = parsed.len; - var serial_types = std.ArrayList(u64){}; - defer serial_types.deinit(allocator); - - while (header_pos < header_size) { + // Parse serial types + var serial_types: [8]u64 = undefined; + var num_st: usize = 0; + while (header_pos < header_size and num_st < 8) { parsed = varint.parse(record_data[header_pos..]); - serial_types.append(allocator, parsed.value) catch break; + serial_types[num_st] = parsed.value; + num_st += 1; header_pos += parsed.len; } - if (serial_types.items.len > 0) { - const st = serial_types.items[0]; + if (num_st >= 2) { + const st = serial_types[0]; var body_pos: usize = header_size; - if (body_pos >= record_data.len) continue; if (st >= 13 and (st % 2) == 1) { const str_result = record.readString(record_data[body_pos..], st); if (std.mem.eql(u8, str_result.value, search_value)) { body_pos += str_result.len; - if (serial_types.items.len > 1 and body_pos < record_data.len) { - const rowid_st = serial_types.items[1]; - const rowid_result = record.readInt(record_data[body_pos..], rowid_st); - rowids.append(allocator, @as(u64, @intCast(rowid_result.value))) catch {}; - } + const rowid_st = serial_types[1]; + const rowid_result = record.readInt(record_data[body_pos..], rowid_st); + try rowids.append(allocator, @as(u64, @intCast(rowid_result.value))); } } } } } else if (page_type == 0x02) { - // Interior index page + // Interior index page - add ALL children to stack const num_cells = std.mem.readInt(u16, page_data[3..5], .big); const rightmost_ptr = std.mem.readInt(u32, page_data[8..12], .big); + // Add all cell pointers for (0..num_cells) |i| { const offset = 12 + i * 2; if (offset + 2 > page_data.len) continue; - const cell_bytes: *const [2]u8 = page_data[offset .. offset + 2][0..2]; - const cell_ptr = std.mem.readInt(u16, cell_bytes, .big); + const cell_ptr = std.mem.readInt(u16, page_data[offset..][0..2], .big); if (cell_ptr + 4 > page_data.len) continue; const cell_data = page_data[cell_ptr..]; const left_child_page = std.mem.readInt(u32, cell_data[0..4], .big); - queue.append(allocator, left_child_page) catch {}; + if (left_child_page > 0) { + try stack.append(allocator, left_child_page); + } } if (rightmost_ptr > 0) { - queue.append(allocator, rightmost_ptr) catch {}; + try stack.append(allocator, rightmost_ptr); } } } } - fn readRecordByRowid(allocator: std.mem.Allocator, file: *std.fs.File, page_size: u16, table_rootpage: u32, target_rowid: u64, column_indices: []const usize, stdout: anytype) !void { try searchTableForRowid(allocator, file, page_size, table_rootpage, target_rowid, column_indices, stdout); } @@ -339,32 +325,52 @@ fn searchTableForRowid(allocator: std.mem.Allocator, file: *std.fs.File, page_si } } +inline fn serialTypeSize(st: u64) usize { + if (st == 0 or st == 8 or st == 9) { + return 0; + } else if (st >= 13 and (st % 2) == 1) { + return (st - 13) / 2; + } else if (st >= 12 and (st % 2) == 0) { + return (st - 12) / 2; + } else if (st >= 1 and st <= 6) { + return @as(usize, st); + } else if (st == 7) { + return 8; + } + return 0; +} + fn readLeafPageRows(page_data: []const u8, column_indices: []const usize, where_column_idx: ?usize, where_value: ?[]const u8, stdout: anytype) !void { const page_type = page_data[0]; if (page_type != 0x0d) return; const num_cells = std.mem.readInt(u16, page_data[3..5], .big); + if (num_cells == 0) return; for (0..num_cells) |i| { const offset = 8 + i * 2; - const cell_ptr_bytes = page_data[offset .. offset + 2]; - const cell_ptr = std.mem.readInt(u16, cell_ptr_bytes[0..2], .big); + if (offset + 2 > page_data.len) continue; + const cell_ptr = std.mem.readInt(u16, page_data[offset..][0..2], .big); + if (cell_ptr >= page_data.len) continue; const cell_data = page_data[cell_ptr..]; var parsed = varint.parse(cell_data); var pos = parsed.len; + if (pos >= cell_data.len) continue; parsed = varint.parse(cell_data[pos..]); const rowid = parsed.value; pos += parsed.len; + if (pos >= cell_data.len) continue; const record_data = cell_data[pos..]; parsed = varint.parse(record_data); const header_size = parsed.value; + if (header_size > record_data.len or header_size == 0) continue; var header_pos = parsed.len; - // Parse serial types into a fixed-size buffer instead of ArrayList + // Parse serial types var serial_types: [256]u64 = undefined; var num_columns: usize = 0; while (header_pos < header_size and num_columns < 256) { @@ -374,63 +380,48 @@ fn readLeafPageRows(page_data: []const u8, column_indices: []const usize, where_ header_pos += parsed.len; } - // Check WHERE clause if present + // Check WHERE clause if present (early rejection) if (where_column_idx) |where_idx| { if (where_value) |expected_value| { if (where_idx >= num_columns) continue; - var where_body_pos: usize = header_size; - for (0..where_idx) |col| { - if (col >= num_columns) break; - const st = serial_types[col]; - if (st == 0 or st == 8 or st == 9) {} else if (st >= 13 and (st % 2) == 1) { - where_body_pos += (st - 13) / 2; - } else if (st >= 12 and (st % 2) == 0) { - where_body_pos += (st - 12) / 2; - } else if (st >= 1 and st <= 6) { - const int_result = record.readInt(record_data[where_body_pos..], st); - where_body_pos += int_result.len; - } else if (st == 7) { - where_body_pos += 8; - } - } - const st = serial_types[where_idx]; - var matches = false; + + // Fast path: check serial type first if (st >= 13 and (st % 2) == 1) { + // String comparison - check length first for early rejection + const expected_len = (st - 13) / 2; + if (expected_len != expected_value.len) continue; + + var where_body_pos: usize = header_size; + for (0..where_idx) |col| { + where_body_pos += serialTypeSize(serial_types[col]); + } + const str_result = record.readString(record_data[where_body_pos..], st); - matches = std.mem.eql(u8, str_result.value, expected_value); + if (!std.mem.eql(u8, str_result.value, expected_value)) continue; } else if (st >= 1 and st <= 6) { + var where_body_pos: usize = header_size; + for (0..where_idx) |col| { + where_body_pos += serialTypeSize(serial_types[col]); + } const int_result = record.readInt(record_data[where_body_pos..], st); const expected_int = std.fmt.parseInt(i64, expected_value, 10) catch 0; - matches = int_result.value == expected_int; + if (int_result.value != expected_int) continue; + } else { + continue; // Unsupported type for WHERE } - - if (!matches) continue; } } + // Print matching row for (column_indices, 0..) |column_idx, col_num| { if (col_num > 0) try stdout.print("|", .{}); - if (column_idx >= num_columns) continue; var body_pos: usize = header_size; for (0..column_idx) |col| { - if (col >= num_columns) break; - const st = serial_types[col]; - if (st == 0 or st == 8 or st == 9) { - // NULL, 0, or 1 - no data - } else if (st >= 13 and (st % 2) == 1) { - body_pos += (st - 13) / 2; - } else if (st >= 12 and (st % 2) == 0) { - body_pos += (st - 12) / 2; - } else if (st >= 1 and st <= 6) { - const int_result = record.readInt(record_data[body_pos..], st); - body_pos += int_result.len; - } else if (st == 7) { - body_pos += 8; // Float - } + body_pos += serialTypeSize(serial_types[col]); } const st = serial_types[column_idx]; @@ -443,12 +434,9 @@ fn readLeafPageRows(page_data: []const u8, column_indices: []const usize, where_ } else if (st >= 1 and st <= 6) { const int_result = record.readInt(record_data[body_pos..], st); try stdout.print("{}", .{int_result.value}); - } else if (st == 7) {} else if (st >= 13 and (st % 2) == 1) { + } else if (st >= 13 and (st % 2) == 1) { const str_result = record.readString(record_data[body_pos..], st); try stdout.print("{s}", .{str_result.value}); - } else if (st >= 12 and (st % 2) == 0) { - const blob_len = (st - 12) / 2; - _ = blob_len; } } try stdout.print("\n", .{}); @@ -456,33 +444,56 @@ fn readLeafPageRows(page_data: []const u8, column_indices: []const usize, where_ } fn traverseBTree(allocator: std.mem.Allocator, file: *std.fs.File, page_size: u16, page_num: u32, column_indices: []const usize, where_column_idx: ?usize, where_value: ?[]const u8, stdout: anytype) !void { - const page_offset = (page_num - 1) * @as(u64, page_size); - var page_data = try allocator.alloc(u8, page_size); - defer allocator.free(page_data); + // Read entire file into memory for fast random access + const file_size = try file.getEndPos(); + const file_data = try allocator.alloc(u8, file_size); + defer allocator.free(file_data); - _ = try file.seekTo(page_offset); - _ = try file.read(page_data); + _ = try file.seekTo(0); + _ = try file.read(file_data); - const page_type = page_data[0]; + var stack = std.ArrayList(u32){}; + defer stack.deinit(allocator); + try stack.append(allocator, page_num); - if (page_type == 0x0d) { - try readLeafPageRows(page_data, column_indices, where_column_idx, where_value, stdout); - } else if (page_type == 0x05) { - const num_cells = std.mem.readInt(u16, page_data[3..5], .big); - const rightmost_ptr = std.mem.readInt(u32, page_data[8..12], .big); + while (stack.items.len > 0) { + const current_page = stack.pop() orelse continue; - for (0..num_cells) |i| { - const offset = 12 + i * 2; - const cell_ptr_bytes = page_data[offset .. offset + 2]; - const cell_ptr = std.mem.readInt(u16, cell_ptr_bytes[0..2], .big); + const page_offset = (current_page - 1) * @as(usize, page_size); + if (page_offset + page_size > file_data.len) continue; - const cell_data = page_data[cell_ptr..]; - const left_child_page = std.mem.readInt(u32, cell_data[0..4], .big); + const page_data = file_data[page_offset .. page_offset + page_size]; + const page_type = page_data[0]; - try traverseBTree(allocator, file, page_size, left_child_page, column_indices, where_column_idx, where_value, stdout); - } + if (page_type == 0x0d) { + try readLeafPageRows(page_data, column_indices, where_column_idx, where_value, stdout); + } else if (page_type == 0x05) { + const num_cells = std.mem.readInt(u16, page_data[3..5], .big); + const rightmost_ptr = std.mem.readInt(u32, page_data[8..12], .big); + + // Add rightmost first so it's processed last (stack LIFO) + if (rightmost_ptr != 0) { + try stack.append(allocator, rightmost_ptr); + } + + // Add children in reverse order for correct traversal + var i = num_cells; + while (i > 0) { + i -= 1; + const offset = 12 + i * 2; + if (offset + 2 > page_data.len) continue; + const cell_ptr_bytes = page_data[offset .. offset + 2]; + const cell_ptr = std.mem.readInt(u16, cell_ptr_bytes[0..2], .big); + if (cell_ptr + 4 > page_data.len) continue; + + const cell_data = page_data[cell_ptr..]; + const left_child_page = std.mem.readInt(u32, cell_data[0..4], .big); - try traverseBTree(allocator, file, page_size, rightmost_ptr, column_indices, where_column_idx, where_value, stdout); + if (left_child_page != 0) { + try stack.append(allocator, left_child_page); + } + } + } } } @@ -855,122 +866,7 @@ pub fn readTableRowsMultiColumn(allocator: std.mem.Allocator, file: *std.fs.File } pub fn readTableRowsMultiColumnWhere(allocator: std.mem.Allocator, file: *std.fs.File, page_size: u16, rootpage: u32, column_indices: []const usize, where_column_idx: ?usize, where_value: ?[]const u8, stdout: anytype) !void { - if (rootpage == 0) return; - - const page_offset = (rootpage - 1) * @as(u64, page_size); - var page_data = try allocator.alloc(u8, page_size); - defer allocator.free(page_data); - - _ = try file.seekTo(page_offset); - _ = try file.read(page_data); - - const page_type = page_data[0]; - if (page_type != 0x0d) return; - - const num_cells = std.mem.readInt(u16, page_data[3..5], .big); - - var cell_pointers = try allocator.alloc(u16, num_cells); - defer allocator.free(cell_pointers); - - for (0..num_cells) |i| { - const offset = 8 + i * 2; - const cell_ptr_bytes = page_data[offset .. offset + 2]; - cell_pointers[i] = std.mem.readInt(u16, cell_ptr_bytes[0..2], .big); - } - - for (0..num_cells) |i| { - const cell_data = page_data[cell_pointers[i]..]; - - var parsed = varint.parse(cell_data); - var pos = parsed.len; - - parsed = varint.parse(cell_data[pos..]); - pos += parsed.len; - - const record_data = cell_data[pos..]; - parsed = varint.parse(record_data); - const header_size = parsed.value; - var header_pos = parsed.len; - - var serial_types = std.ArrayList(u64){}; - defer serial_types.deinit(allocator); - - while (header_pos < header_size) { - parsed = varint.parse(record_data[header_pos..]); - try serial_types.append(allocator, parsed.value); - header_pos += parsed.len; - } - - // Check WHERE condition if present - if (where_column_idx) |where_idx| { - if (where_value) |expected_value| { - // Calculate position for WHERE column - var where_body_pos: usize = header_size; - for (0..where_idx) |col| { - if (col >= serial_types.items.len) break; - const st = serial_types.items[col]; - if (st >= 13 and (st % 2) == 1) { - where_body_pos += (st - 13) / 2; - } else if (st >= 1 and st <= 6) { - const int_result = record.readInt(record_data[where_body_pos..], st); - where_body_pos += int_result.len; - } - } - - // Read WHERE column value - var matches = false; - if (where_idx < serial_types.items.len) { - const st = serial_types.items[where_idx]; - if (st >= 13 and (st % 2) == 1) { - const str_result = record.readString(record_data[where_body_pos..], st); - matches = std.mem.eql(u8, str_result.value, expected_value); - } else if (st >= 1 and st <= 6) { - const int_result = record.readInt(record_data[where_body_pos..], st); - // Try to parse expected_value as integer - const expected_int = std.fmt.parseInt(i64, expected_value, 10) catch 0; - matches = int_result.value == expected_int; - } - } - - // Skip this row if it doesn't match - if (!matches) continue; - } - } - - // For each requested column - for (column_indices, 0..) |column_idx, col_num| { - // Calculate position for this column - var body_pos: usize = header_size; - for (0..column_idx) |col| { - if (col >= serial_types.items.len) break; - const st = serial_types.items[col]; - if (st >= 13 and (st % 2) == 1) { - body_pos += (st - 13) / 2; - } else if (st >= 1 and st <= 6) { - const int_result = record.readInt(record_data[body_pos..], st); - body_pos += int_result.len; - } - } - - // Print separator if not first column - if (col_num > 0) { - try stdout.print("|", .{}); - } - - // Read and print the column value - if (column_idx < serial_types.items.len) { - const st = serial_types.items[column_idx]; - if (st >= 13 and (st % 2) == 1) { - const str_result = record.readString(record_data[body_pos..], st); - try stdout.print("{s}", .{str_result.value}); - } else if (st >= 1 and st <= 6) { - const int_result = record.readInt(record_data[body_pos..], st); - try stdout.print("{}", .{int_result.value}); - } - } - } - try stdout.print("\n", .{}); - } + try traverseBTree(allocator, file, page_size, rootpage, column_indices, where_column_idx, where_value, stdout); } pub fn readTableRowsWithIndex(allocator: std.mem.Allocator, file: *std.fs.File, page_size: u16, table_name: []const u8, table_rootpage: u32, column_indices: []const usize, where_column: []const u8, where_value: []const u8, stdout: anytype) !void { @@ -983,5 +879,6 @@ pub fn readTableRowsWithIndex(allocator: std.mem.Allocator, file: *std.fs.File, _ = where_column; _ = where_value; _ = stdout; + // Temporarily disable to test table scan return error.NoIndexFound; } |