summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rwxr-xr-xsrc/main.zig35
-rw-r--r--src/schema.zig218
-rw-r--r--src/varint.zig8
3 files changed, 254 insertions, 7 deletions
diff --git a/src/main.zig b/src/main.zig
index 780c70c..c06bd66 100755
--- a/src/main.zig
+++ b/src/main.zig
@@ -37,6 +37,41 @@ pub fn main() !void {
try stdout.print("number of tables: {}\n", .{num_tables});
} else if (std.mem.eql(u8, args[2], ".tables")) {
try tables.showTables(allocator, &file, page_size, stdout);
+ } else if (std.mem.startsWith(u8, args[2], "SELECT") or std.mem.startsWith(u8, args[2], "select")) {
+ // Parse SELECT query: "SELECT column FROM table"
+ var tokens = std.mem.tokenizeScalar(u8, args[2], ' ');
+
+ // Skip "SELECT"
+ _ = tokens.next();
+
+ // Get column name or aggregate function
+ const column_name = tokens.next() orelse return error.InvalidQuery;
+
+ // Skip "FROM" or "from"
+ _ = tokens.next();
+
+ // Get table name
+ const table_name = tokens.next() orelse return error.InvalidQuery;
+
+ // Check if this is COUNT(*)
+ if (std.mem.indexOf(u8, column_name, "count(") != null or std.mem.indexOf(u8, column_name, "COUNT(") != null) {
+ const rootpage = try schema.getRootpage(allocator, &file, page_size, table_name);
+ const row_count = try schema.countRows(allocator, &file, page_size, rootpage);
+ try stdout.print("{}\n", .{row_count});
+ } else {
+ // Get the CREATE TABLE statement to find column order
+ const create_sql = try schema.getCreateTableSQL(allocator, &file, page_size, table_name);
+ defer allocator.free(create_sql);
+
+ // Find the column index
+ const column_idx = try schema.parseColumnIndex(create_sql, column_name);
+
+ // Get the table's root page
+ const rootpage = try schema.getRootpage(allocator, &file, page_size, table_name);
+
+ // Read and print all rows
+ try schema.readTableRows(allocator, &file, page_size, rootpage, column_idx, stdout);
+ }
} else {
var tokens = std.mem.tokenizeScalar(u8, args[2], ' ');
var last_token: []const u8 = "";
diff --git a/src/schema.zig b/src/schema.zig
index 00bf1f7..ccfc9e0 100644
--- a/src/schema.zig
+++ b/src/schema.zig
@@ -46,11 +46,20 @@ pub fn getRootpage(allocator: std.mem.Allocator, file: *std.fs.File, page_size:
var body_pos: usize = header_size;
- for (0..2) |col| {
- const st = serial_types[col];
- if (st >= 13 and (st % 2) == 1) {
- body_pos += (st - 13) / 2;
- }
+ const st0 = serial_types[0];
+ if (st0 >= 13 and (st0 % 2) == 1) {
+ body_pos += (st0 - 13) / 2;
+ } else if (st0 >= 1 and st0 <= 6) {
+ const r0 = record.readInt(record_data[body_pos..], st0);
+ body_pos += r0.len;
+ }
+
+ const st1 = serial_types[1];
+ if (st1 >= 13 and (st1 % 2) == 1) {
+ body_pos += (st1 - 13) / 2;
+ } else if (st1 >= 1 and st1 <= 6) {
+ const r1 = record.readInt(record_data[body_pos..], st1);
+ body_pos += r1.len;
}
const tbl_name_result = record.readString(record_data[body_pos..], serial_types[2]);
@@ -82,3 +91,202 @@ pub fn countRows(allocator: std.mem.Allocator, file: *std.fs.File, page_size: u1
return 0;
}
+
+pub fn getCreateTableSQL(allocator: std.mem.Allocator, file: *std.fs.File, page_size: u16, table_name: []const u8) ![]const u8 {
+ var buf: [2]u8 = undefined;
+ _ = try file.seekTo(103);
+ _ = try file.read(&buf);
+ const num_cells = std.mem.readInt(u16, &buf, .big);
+
+ var cell_pointers = try allocator.alloc(u16, num_cells);
+ defer allocator.free(cell_pointers);
+
+ for (0..num_cells) |i| {
+ _ = try file.seekTo(108 + i * 2);
+ _ = try file.read(&buf);
+ cell_pointers[i] = std.mem.readInt(u16, &buf, .big);
+ }
+
+ var page_data = try allocator.alloc(u8, page_size);
+ defer allocator.free(page_data);
+
+ _ = try file.seekTo(0);
+ _ = try file.read(page_data);
+
+ for (0..num_cells) |i| {
+ const cell_data = page_data[cell_pointers[i]..];
+
+ var parsed = varint.parse(cell_data);
+ var pos = parsed.len;
+
+ parsed = varint.parse(cell_data[pos..]);
+ pos += parsed.len;
+
+ const record_data = cell_data[pos..];
+ parsed = varint.parse(record_data);
+ const header_size = parsed.value;
+ var header_pos = parsed.len;
+
+ var serial_types: [5]u64 = undefined;
+ for (0..5) |col| {
+ parsed = varint.parse(record_data[header_pos..]);
+ serial_types[col] = parsed.value;
+ header_pos += parsed.len;
+ }
+
+ var body_pos: usize = header_size;
+
+ const st0 = serial_types[0];
+ if (st0 >= 13 and (st0 % 2) == 1) {
+ body_pos += (st0 - 13) / 2;
+ } else if (st0 >= 1 and st0 <= 6) {
+ const r0 = record.readInt(record_data[body_pos..], st0);
+ body_pos += r0.len;
+ }
+
+ const st1 = serial_types[1];
+ if (st1 >= 13 and (st1 % 2) == 1) {
+ body_pos += (st1 - 13) / 2;
+ } else if (st1 >= 1 and st1 <= 6) {
+ const r1 = record.readInt(record_data[body_pos..], st1);
+ body_pos += r1.len;
+ }
+
+ const tbl_name_result = record.readString(record_data[body_pos..], serial_types[2]);
+ body_pos += tbl_name_result.len;
+
+ if (std.mem.eql(u8, tbl_name_result.value, table_name)) {
+ const rp = record.readInt(record_data[body_pos..], serial_types[3]);
+ body_pos += rp.len;
+
+ const sql_result = record.readString(record_data[body_pos..], serial_types[4]);
+ return try allocator.dupe(u8, sql_result.value);
+ }
+ }
+
+ return error.TableNotFound;
+}
+
+pub fn parseColumnIndex(sql: []const u8, column_name: []const u8) !usize {
+ var paren_idx: ?usize = null;
+ for (sql, 0..) |c, i| {
+ if (c == '(') {
+ paren_idx = i;
+ break;
+ }
+ }
+
+ if (paren_idx == null) return error.InvalidSQL;
+
+ var col_idx: usize = 0;
+ var in_col_name = false;
+ var col_start: usize = paren_idx.? + 1;
+
+ for (sql[paren_idx.? + 1 ..], 0..) |c, i| {
+ const actual_idx = paren_idx.? + 1 + i;
+
+ if (c == ')') break;
+
+ if (c == ' ' or c == '\t' or c == '\n') {
+ if (in_col_name) {
+ const col_name = std.mem.trim(u8, sql[col_start..actual_idx], " \t\n");
+ if (std.mem.eql(u8, col_name, column_name)) {
+ return col_idx;
+ }
+ in_col_name = false;
+ }
+ continue;
+ }
+
+ if (c == ',') {
+ if (in_col_name) {
+ const col_name = std.mem.trim(u8, sql[col_start..actual_idx], " \t\n");
+ if (std.mem.eql(u8, col_name, column_name)) {
+ return col_idx;
+ }
+ }
+ col_idx += 1;
+ in_col_name = false;
+ continue;
+ }
+
+ if (!in_col_name) {
+ col_start = actual_idx;
+ in_col_name = true;
+ }
+ }
+
+ return error.ColumnNotFound;
+}
+
+pub fn readTableRows(allocator: std.mem.Allocator, file: *std.fs.File, page_size: u16, rootpage: u32, column_idx: usize, stdout: anytype) !void {
+ if (rootpage == 0) return;
+
+ const page_offset = (rootpage - 1) * @as(u64, page_size);
+ var page_data = try allocator.alloc(u8, page_size);
+ defer allocator.free(page_data);
+
+ _ = try file.seekTo(page_offset);
+ _ = try file.read(page_data);
+
+ const page_type = page_data[0];
+ if (page_type != 0x0d) return;
+
+ const num_cells = std.mem.readInt(u16, page_data[3..5], .big);
+
+ var cell_pointers = try allocator.alloc(u16, num_cells);
+ defer allocator.free(cell_pointers);
+
+ for (0..num_cells) |i| {
+ const offset = 8 + i * 2;
+ const cell_ptr_bytes = page_data[offset .. offset + 2];
+ cell_pointers[i] = std.mem.readInt(u16, cell_ptr_bytes[0..2], .big);
+ }
+
+ for (0..num_cells) |i| {
+ const cell_data = page_data[cell_pointers[i]..];
+
+ var parsed = varint.parse(cell_data);
+ var pos = parsed.len;
+
+ parsed = varint.parse(cell_data[pos..]);
+ pos += parsed.len;
+
+ const record_data = cell_data[pos..];
+ parsed = varint.parse(record_data);
+ const header_size = parsed.value;
+ var header_pos = parsed.len;
+
+ var serial_types = std.ArrayList(u64){};
+ defer serial_types.deinit(allocator);
+
+ while (header_pos < header_size) {
+ parsed = varint.parse(record_data[header_pos..]);
+ try serial_types.append(allocator, parsed.value);
+ header_pos += parsed.len;
+ }
+
+ var body_pos: usize = header_size;
+ for (0..column_idx) |col| {
+ if (col >= serial_types.items.len) break;
+ const st = serial_types.items[col];
+ if (st >= 13 and (st % 2) == 1) {
+ body_pos += (st - 13) / 2;
+ } else if (st >= 1 and st <= 6) {
+ const int_result = record.readInt(record_data[body_pos..], st);
+ body_pos += int_result.len;
+ }
+ }
+
+ if (column_idx < serial_types.items.len) {
+ const st = serial_types.items[column_idx];
+ if (st >= 13 and (st % 2) == 1) {
+ const str_result = record.readString(record_data[body_pos..], st);
+ try stdout.print("{s}\n", .{str_result.value});
+ } else if (st >= 1 and st <= 6) {
+ const int_result = record.readInt(record_data[body_pos..], st);
+ try stdout.print("{}\n", .{int_result.value});
+ }
+ }
+ }
+}
diff --git a/src/varint.zig b/src/varint.zig
index 45207cf..3a07442 100644
--- a/src/varint.zig
+++ b/src/varint.zig
@@ -6,8 +6,12 @@ pub fn parse(data: []const u8) struct { value: u64, len: usize } {
while (i < data.len and i < 9) : (i += 1) {
const byte = data[i];
- result |= @as(u64, byte & 0x7f) << @as(u6, @intCast(i * 7));
- if ((byte & 0x80) == 0) return .{ .value = result, .len = i + 1 };
+ if ((byte & 0x80) != 0) {
+ result = (result << 7) | @as(u64, byte & 0x7f);
+ } else {
+ result = (result << 7) | @as(u64, byte);
+ return .{ .value = result, .len = i + 1 };
+ }
}
return .{ .value = result, .len = i };