using System.Globalization;
using System.IO;
using System.Text;
using System.Text.Json;
using Microsoft.Data.Sqlite;

namespace AxCopilot.Services.Agent;

/// <summary>
/// SQLite database query tool.
/// Runs SELECT/INSERT/UPDATE/DELETE statements against local .db/.sqlite files.
/// </summary>
public class SqlTool : IAgentTool
{
    /// <summary>Rows returned by the 'query' action when 'max_rows' is absent or invalid.</summary>
    private const int DefaultMaxRows = 100;

    /// <summary>Hard upper bound for 'max_rows'.</summary>
    private const int MaxRowsLimit = 1000;

    /// <summary>Maximum characters of result table text returned before truncation.</summary>
    private const int MaxOutputChars = 8000;

    private const string BlockedMessage = "Security: DROP DATABASE, ATTACH, DETACH are not allowed.";

    public string Name => "sql_tool";

    public string Description =>
        "Execute SQL queries on local SQLite database files. Actions: " +
        "'query' — run SELECT query and return results as table; " +
        "'execute' — run INSERT/UPDATE/DELETE and return affected rows; " +
        "'schema' — show database schema (tables, columns, types); " +
        "'tables' — list all tables in the database.";

    public ToolParameterSchema Parameters => new()
    {
        Properties = new()
        {
            ["action"] = new()
            {
                Type = "string",
                Description = "Action to perform",
                Enum = ["query", "execute", "schema", "tables"],
            },
            ["db_path"] = new()
            {
                Type = "string",
                Description = "Path to SQLite database file (.db, .sqlite, .sqlite3)",
            },
            ["sql"] = new()
            {
                Type = "string",
                Description = "SQL query to execute (for query/execute actions)",
            },
            ["max_rows"] = new()
            {
                Type = "string",
                Description = "Maximum rows to return (default: 100, max: 1000)",
            },
        },
        Required = ["action", "db_path"],
    };

    /// <summary>
    /// Opens the database named by 'db_path' (relative paths resolve against
    /// <c>context.WorkFolder</c>) and dispatches to the requested action.
    /// </summary>
    /// <returns>
    /// A failed result for missing/invalid arguments, a missing file, an unknown action or any
    /// exception raised while running SQL; otherwise the action's result. A canceled task is
    /// returned when <paramref name="ct"/> is canceled.
    /// </returns>
    public Task<ToolResult> ExecuteAsync(JsonElement args, AgentContext context, CancellationToken ct = default)
    {
        var action = GetStringArg(args, "action");
        var dbPath = GetStringArg(args, "db_path");

        if (string.IsNullOrWhiteSpace(dbPath))
            return Task.FromResult(ToolResult.Fail("'db_path' parameter is required"));

        if (!Path.IsPathRooted(dbPath))
            dbPath = Path.Combine(context.WorkFolder, dbPath);

        if (!File.Exists(dbPath))
            return Task.FromResult(ToolResult.Fail($"Database file not found: {dbPath}"));

        try
        {
            ct.ThrowIfCancellationRequested();

            // Use the builder so a path containing ';' or '=' cannot break (or inject keywords
            // into) the connection string.
            var connStr = new SqliteConnectionStringBuilder
            {
                DataSource = dbPath,
                // Only 'execute' may modify the database; every other action opens read-only.
                Mode = action == "execute" ? SqliteOpenMode.ReadWrite : SqliteOpenMode.ReadOnly,
                // Without pooling the file handle is released as soon as the connection is disposed,
                // so the user's database file is not left locked after the tool returns.
                Pooling = false,
            }.ToString();

            using var conn = new SqliteConnection(connStr);
            conn.Open();

            var result = action switch
            {
                "query" => QueryAction(conn, args, ct),
                "execute" => ExecuteAction(conn, args),
                "schema" => SchemaAction(conn),
                "tables" => TablesAction(conn),
                _ => ToolResult.Fail($"Unknown action: {action}"),
            };
            return Task.FromResult(result);
        }
        catch (OperationCanceledException) when (ct.IsCancellationRequested)
        {
            return Task.FromCanceled<ToolResult>(ct);
        }
        catch (Exception ex)
        {
            return Task.FromResult(ToolResult.Fail($"SQL 오류: {ex.Message}"));
        }
    }

    /// <summary>
    /// Runs a read-only statement (first keyword SELECT/WITH/PRAGMA) and renders up to
    /// 'max_rows' rows as a pipe-separated table.
    /// </summary>
    private static ToolResult QueryAction(SqliteConnection conn, JsonElement args, CancellationToken ct)
    {
        var sql = GetStringArg(args, "sql");
        if (string.IsNullOrWhiteSpace(sql))
            return ToolResult.Fail("'sql' parameter is required for query action");

        var tokens = Tokenize(sql);
        // Leading comments are skipped by Tokenize, so only the first real keyword is checked.
        if (tokens.Count == 0 || tokens[0] is not ("SELECT" or "WITH" or "PRAGMA"))
            return ToolResult.Fail("Query action only allows SELECT/WITH/PRAGMA statements. Use 'execute' for modifications.");
        if (ContainsBlockedStatement(tokens))
            return ToolResult.Fail(BlockedMessage);

        var maxRows = ParseMaxRows(args);

        using var cmd = conn.CreateCommand();
        cmd.CommandText = sql;
        using var reader = cmd.ExecuteReader();

        var colCount = reader.FieldCount;
        var colNames = new string[colCount];
        for (var i = 0; i < colCount; i++)
            colNames[i] = reader.GetName(i);

        var sb = new StringBuilder();
        sb.AppendLine(string.Join(" | ", colNames));
        sb.AppendLine(new string('-', colNames.Sum(c => c.Length + 3)));

        var rowCount = 0;
        var values = new string[colCount];
        // Check the limit before Read() so no row is consumed and then silently discarded.
        while (rowCount < maxRows && reader.Read())
        {
            ct.ThrowIfCancellationRequested();
            for (var i = 0; i < colCount; i++)
                values[i] = FormatValue(reader, i);
            sb.AppendLine(string.Join(" | ", values));
            rowCount++;
        }

        if (rowCount == 0)
            return ToolResult.Ok("Query returned 0 rows.");

        // Report the limit only when at least one more row actually exists.
        var limited = rowCount >= maxRows && reader.Read();

        var table = sb.ToString();
        if (table.Length > MaxOutputChars)
        {
            var cut = MaxOutputChars;
            // Do not split a UTF-16 surrogate pair at the cut point.
            if (char.IsHighSurrogate(table[cut - 1]))
                cut--;
            table = table[..cut] + "\n... (truncated)";
        }

        return ToolResult.Ok($"Rows: {rowCount}" + (limited ? $" (limited to {maxRows})" : "") + $"\n\n{table}");
    }

    /// <summary>
    /// Runs a modifying statement on the read-write connection and reports the affected row
    /// count returned by <c>ExecuteNonQuery</c>.
    /// </summary>
    private static ToolResult ExecuteAction(SqliteConnection conn, JsonElement args)
    {
        var sql = GetStringArg(args, "sql");
        if (string.IsNullOrWhiteSpace(sql))
            return ToolResult.Fail("'sql' parameter is required for execute action");

        // Scan every statement, not only the prefix, so "SELECT 1; ATTACH ..." or a leading
        // comment cannot slip past the guard.
        // NOTE(review): VACUUM INTO can also write a file at an arbitrary path — consider blocking.
        if (ContainsBlockedStatement(Tokenize(sql)))
            return ToolResult.Fail(BlockedMessage);

        using var cmd = conn.CreateCommand();
        cmd.CommandText = sql;
        var affected = cmd.ExecuteNonQuery();
        return ToolResult.Ok($"✓ {affected} row(s) affected");
    }

    /// <summary>
    /// Lists every table in <c>sqlite_master</c> with its columns as reported by
    /// <c>PRAGMA table_info</c> (ordinal, name, type, NOT NULL, default, PK).
    /// </summary>
    private static ToolResult SchemaAction(SqliteConnection conn)
    {
        var tables = new List<string>();
        using (var cmd = conn.CreateCommand())
        {
            cmd.CommandText = "SELECT name FROM sqlite_master WHERE type='table' ORDER BY name";
            using var reader = cmd.ExecuteReader();
            while (reader.Read())
                tables.Add(reader.GetString(0));
        }

        if (tables.Count == 0)
            return ToolResult.Ok("No tables found.");

        var sb = new StringBuilder();
        foreach (var table in tables)
        {
            sb.AppendLine($"## {table}");

            // Double embedded quotes so a table name containing '"' stays a valid identifier.
            var quotedName = table.Replace("\"", "\"\"");
            using var pragmaCmd = conn.CreateCommand();
            pragmaCmd.CommandText = $"PRAGMA table_info(\"{quotedName}\")";
            using var pragmaReader = pragmaCmd.ExecuteReader();

            sb.AppendLine($"{"#",-4} {"Name",-25} {"Type",-15} {"NotNull",-8} {"Default",-15} {"PK"}");
            while (pragmaReader.Read())
            {
                sb.AppendLine(
                    $"{pragmaReader.GetInt32(0),-4} " +
                    $"{pragmaReader.GetString(1),-25} " +
                    $"{pragmaReader.GetString(2),-15} " +
                    $"{(pragmaReader.GetInt32(3) == 1 ? "YES" : ""),-8} " +
                    $"{(pragmaReader.IsDBNull(4) ? "" : pragmaReader.GetString(4)),-15} " +
                    $"{(pragmaReader.GetInt32(5) > 0 ? "PK" : "")}");
            }
            sb.AppendLine();
        }

        return ToolResult.Ok(sb.ToString());
    }

    /// <summary>Lists tables and views from <c>sqlite_master</c> with their column counts.</summary>
    private static ToolResult TablesAction(SqliteConnection conn)
    {
        using var cmd = conn.CreateCommand();
        cmd.CommandText = @"
            SELECT m.name, m.type,
                   (SELECT count(*) FROM pragma_table_info(m.name)) as col_count
            FROM sqlite_master m
            WHERE m.type IN ('table','view')
            ORDER BY m.type, m.name";
        using var reader = cmd.ExecuteReader();

        var sb = new StringBuilder();
        sb.AppendLine($"{"Name",-30} {"Type",-8} {"Columns"}");
        sb.AppendLine(new string('-', 50));

        var count = 0;
        while (reader.Read())
        {
            sb.AppendLine($"{reader.GetString(0),-30} {reader.GetString(1),-8} {reader.GetInt32(2)}");
            count++;
        }

        return ToolResult.Ok($"Found {count} tables/views:\n\n{sb}");
    }

    /// <summary>
    /// Reads a string argument; returns "" when <paramref name="args"/> is not an object or the
    /// property is missing, null, or not a JSON string (instead of throwing).
    /// </summary>
    private static string GetStringArg(JsonElement args, string name) =>
        args.ValueKind == JsonValueKind.Object
        && args.TryGetProperty(name, out var prop)
        && prop.ValueKind == JsonValueKind.String
            ? prop.GetString() ?? ""
            : "";

    /// <summary>
    /// Reads 'max_rows' as either a JSON string or number and clamps it to [1, MaxRowsLimit];
    /// falls back to DefaultMaxRows when it is absent or not an integer.
    /// </summary>
    private static int ParseMaxRows(JsonElement args)
    {
        if (args.ValueKind != JsonValueKind.Object || !args.TryGetProperty("max_rows", out var prop))
            return DefaultMaxRows;

        int? parsed = prop.ValueKind switch
        {
            JsonValueKind.Number when prop.TryGetInt32(out var n) => n,
            JsonValueKind.String when int.TryParse(prop.GetString(), NumberStyles.Integer, CultureInfo.InvariantCulture, out var s) => s,
            _ => null,
        };

        return parsed is int value ? Math.Clamp(value, 1, MaxRowsLimit) : DefaultMaxRows;
    }

    /// <summary>
    /// Formats one cell: "NULL" for DB nulls, a size marker for BLOBs, and culture-invariant
    /// text for every other value so numbers render the same on every machine.
    /// </summary>
    private static string FormatValue(SqliteDataReader reader, int ordinal)
    {
        if (reader.IsDBNull(ordinal))
            return "NULL";

        return reader.GetValue(ordinal) switch
        {
            byte[] blob => $"<BLOB {blob.Length} bytes>",
            var value => Convert.ToString(value, CultureInfo.InvariantCulture) ?? "",
        };
    }

    /// <summary>
    /// Splits SQL into upper-cased words, skipping string literals, quoted identifiers
    /// ('...', "...", `...`, [...]) and comments (-- and /* */), so that words inside them are
    /// never mistaken for statements.
    /// </summary>
    private static List<string> Tokenize(string sql)
    {
        var tokens = new List<string>();
        var i = 0;
        while (i < sql.Length)
        {
            var c = sql[i];
            if (c == '-' && i + 1 < sql.Length && sql[i + 1] == '-')
            {
                var end = sql.IndexOf('\n', i);
                i = end < 0 ? sql.Length : end + 1;
            }
            else if (c == '/' && i + 1 < sql.Length && sql[i + 1] == '*')
            {
                var end = sql.IndexOf("*/", i + 2, StringComparison.Ordinal);
                i = end < 0 ? sql.Length : end + 2;
            }
            else if (c is '\'' or '"' or '`' or '[')
            {
                var close = c == '[' ? ']' : c;
                i++;
                while (i < sql.Length)
                {
                    if (sql[i] == close)
                    {
                        // A doubled quote inside a quoted span is an escaped quote, not the end.
                        if (close != ']' && i + 1 < sql.Length && sql[i + 1] == close)
                        {
                            i += 2;
                            continue;
                        }
                        i++;
                        break;
                    }
                    i++;
                }
            }
            else if (char.IsLetter(c) || c == '_')
            {
                var start = i;
                while (i < sql.Length && (char.IsLetterOrDigit(sql[i]) || sql[i] is '_' or '$'))
                    i++;
                tokens.Add(sql[start..i].ToUpperInvariant());
            }
            else
            {
                i++;
            }
        }
        return tokens;
    }

    /// <summary>
    /// True when the token stream contains ATTACH, DETACH, or DROP immediately followed by DATABASE.
    /// </summary>
    private static bool ContainsBlockedStatement(List<string> tokens)
    {
        for (var i = 0; i < tokens.Count; i++)
        {
            if (tokens[i] is "ATTACH" or "DETACH")
                return true;
            if (tokens[i] == "DROP" && i + 1 < tokens.Count && tokens[i + 1] == "DATABASE")
                return true;
        }
        return false;
    }
}