diff options
| author | Lucas Faria Mendes <lucas.oliveira1676@etec.sp.gov.br> | 2025-12-05 09:56:34 +0000 |
|---|---|---|
| committer | Lucas Faria Mendes <lucas.oliveira1676@etec.sp.gov.br> | 2025-12-05 09:56:34 +0000 |
| commit | ebe1de16c418897760af72246aad5d6098c56cfc (patch) | |
| tree | 42c51d4fee0ccaafa360d5539f5f188c38da8a27 /src | |
| parent | 696362172c20e4e08a412f3412b711d0a78811db (diff) | |
| download | sqlite-zig-ebe1de16c418897760af72246aad5d6098c56cfc.tar.gz sqlite-zig-ebe1de16c418897760af72246aad5d6098c56cfc.zip | |
codecrafters submit [skip ci]
Diffstat (limited to 'src')
| -rwxr-xr-x | src/main.zig | 39 | ||||
| -rw-r--r-- | src/schema.zig | 119 |
2 files changed, 155 insertions, 3 deletions
diff --git a/src/main.zig b/src/main.zig index 4394d13..6e1f547 100755 --- a/src/main.zig +++ b/src/main.zig @@ -48,9 +48,36 @@ pub fn main() !void { const select_len: usize = 6; // length of "SELECT" or "select" const column_part = std.mem.trim(u8, args[2][select_len..from_idx], " \t\n"); - // Extract table name (after FROM) + // Find WHERE clause if it exists + const where_idx_upper = std.mem.indexOf(u8, args[2][from_idx..], "WHERE"); + const where_idx_lower = std.mem.indexOf(u8, args[2][from_idx..], "where"); + const where_idx_rel = where_idx_upper orelse where_idx_lower; + + // Extract table name (after FROM, before WHERE if exists) const from_end = from_idx + 4; // length of "FROM" - const table_name = std.mem.trim(u8, args[2][from_end..], " \t\n"); + const table_end = if (where_idx_rel) |idx| from_idx + idx else args[2].len; + const table_name = std.mem.trim(u8, args[2][from_end..table_end], " \t\n"); + + // Parse WHERE clause if it exists + var where_column: ?[]const u8 = null; + var where_value: ?[]const u8 = null; + if (where_idx_rel) |idx| { + const where_start = from_idx + idx + 5; // +5 for "WHERE" + const where_clause = std.mem.trim(u8, args[2][where_start..], " \t\n"); + + // Parse WHERE clause: "column = 'value'" + if (std.mem.indexOf(u8, where_clause, "=")) |eq_idx| { + where_column = std.mem.trim(u8, where_clause[0..eq_idx], " \t\n"); + var value_part = std.mem.trim(u8, where_clause[eq_idx + 1 ..], " \t\n"); + + // Remove quotes from value + if (value_part.len >= 2 and value_part[0] == '\'' and value_part[value_part.len - 1] == '\'') { + where_value = value_part[1 .. value_part.len - 1]; + } else { + where_value = value_part; + } + } + } // Check if this is COUNT(*) if (std.mem.indexOf(u8, column_part, "count(") != null or std.mem.indexOf(u8, column_part, "COUNT(") != null) { @@ -75,6 +102,12 @@ pub fn main() !void { // Get the table's root page const rootpage = try schema.getRootpage(allocator, &file, page_size, table_name); + // Get WHERE column index if present + var where_column_idx: ?usize = null; + if (where_column) |col| { + where_column_idx = try schema.parseColumnIndex(create_sql, col); + } + if (column_list.items.len == 1) { // Single column query const column_idx = try schema.parseColumnIndex(create_sql, column_list.items[0]); @@ -89,7 +122,7 @@ pub fn main() !void { try column_indices.append(allocator, idx); } - try schema.readTableRowsMultiColumn(allocator, &file, page_size, rootpage, column_indices.items, stdout); + try schema.readTableRowsMultiColumnWhere(allocator, &file, page_size, rootpage, column_indices.items, where_column_idx, where_value, stdout); } } } else { diff --git a/src/schema.zig b/src/schema.zig index 5a6d60b..fce45ff 100644 --- a/src/schema.zig +++ b/src/schema.zig @@ -373,3 +373,122 @@ pub fn readTableRowsMultiColumn(allocator: std.mem.Allocator, file: *std.fs.File try stdout.print("\n", .{}); } } + +pub fn readTableRowsMultiColumnWhere(allocator: std.mem.Allocator, file: *std.fs.File, page_size: u16, rootpage: u32, column_indices: []const usize, where_column_idx: ?usize, where_value: ?[]const u8, stdout: anytype) !void { + if (rootpage == 0) return; + + const page_offset = (rootpage - 1) * @as(u64, page_size); + var page_data = try allocator.alloc(u8, page_size); + defer allocator.free(page_data); + + _ = try file.seekTo(page_offset); + _ = try file.read(page_data); + + const page_type = page_data[0]; + if (page_type != 0x0d) return; + + const num_cells = std.mem.readInt(u16, page_data[3..5], .big); + + var cell_pointers = try allocator.alloc(u16, num_cells); + defer allocator.free(cell_pointers); + + for (0..num_cells) |i| { + const offset = 8 + i * 2; + const cell_ptr_bytes = page_data[offset .. offset + 2]; + cell_pointers[i] = std.mem.readInt(u16, cell_ptr_bytes[0..2], .big); + } + + for (0..num_cells) |i| { + const cell_data = page_data[cell_pointers[i]..]; + + var parsed = varint.parse(cell_data); + var pos = parsed.len; + + parsed = varint.parse(cell_data[pos..]); + pos += parsed.len; + + const record_data = cell_data[pos..]; + parsed = varint.parse(record_data); + const header_size = parsed.value; + var header_pos = parsed.len; + + var serial_types = std.ArrayList(u64){}; + defer serial_types.deinit(allocator); + + while (header_pos < header_size) { + parsed = varint.parse(record_data[header_pos..]); + try serial_types.append(allocator, parsed.value); + header_pos += parsed.len; + } + + // Check WHERE condition if present + if (where_column_idx) |where_idx| { + if (where_value) |expected_value| { + // Calculate position for WHERE column + var where_body_pos: usize = header_size; + for (0..where_idx) |col| { + if (col >= serial_types.items.len) break; + const st = serial_types.items[col]; + if (st >= 13 and (st % 2) == 1) { + where_body_pos += (st - 13) / 2; + } else if (st >= 1 and st <= 6) { + const int_result = record.readInt(record_data[where_body_pos..], st); + where_body_pos += int_result.len; + } + } + + // Read WHERE column value + var matches = false; + if (where_idx < serial_types.items.len) { + const st = serial_types.items[where_idx]; + if (st >= 13 and (st % 2) == 1) { + const str_result = record.readString(record_data[where_body_pos..], st); + matches = std.mem.eql(u8, str_result.value, expected_value); + } else if (st >= 1 and st <= 6) { + const int_result = record.readInt(record_data[where_body_pos..], st); + // Try to parse expected_value as integer + const expected_int = std.fmt.parseInt(i64, expected_value, 10) catch 0; + matches = int_result.value == expected_int; + } + } + + // Skip this row if it doesn't match + if (!matches) continue; + } + } + + // For each requested column + for (column_indices, 0..) |column_idx, col_num| { + // Calculate position for this column + var body_pos: usize = header_size; + for (0..column_idx) |col| { + if (col >= serial_types.items.len) break; + const st = serial_types.items[col]; + if (st >= 13 and (st % 2) == 1) { + body_pos += (st - 13) / 2; + } else if (st >= 1 and st <= 6) { + const int_result = record.readInt(record_data[body_pos..], st); + body_pos += int_result.len; + } + } + + // Print separator if not first column + if (col_num > 0) { + try stdout.print("|", .{}); + } + + // Read and print the column value + if (column_idx < serial_types.items.len) { + const st = serial_types.items[column_idx]; + if (st >= 13 and (st % 2) == 1) { + const str_result = record.readString(record_data[body_pos..], st); + try stdout.print("{s}", .{str_result.value}); + } else if (st >= 1 and st <= 6) { + const int_result = record.readInt(record_data[body_pos..], st); + try stdout.print("{}", .{int_result.value}); + } + } + } + try stdout.print("\n", .{}); + } +} |