// Files
// AX-Copilot-Codex/src/AxCopilot/Services/Agent/SqlTool.cs
//
// 211 lines
// 7.9 KiB
// C#

using System.IO;
using System.Text;
using System.Text.Json;
using System.Text.RegularExpressions;
using Microsoft.Data.Sqlite;
namespace AxCopilot.Services.Agent;
/// <summary>
/// SQLite database query tool.
/// Runs SELECT/INSERT/UPDATE/DELETE queries against local .db/.sqlite files.
/// </summary>
/// <remarks>
/// Every action except <c>execute</c> opens the database read-only. A relative <c>db_path</c>
/// is resolved against <c>AgentContext.WorkFolder</c>; an absolute path is used as given.
/// </remarks>
public class SqlTool : IAgentTool
{
    // Rows rendered by the 'query' action when max_rows is absent or unparsable.
    private const int DefaultMaxRows = 100;

    // Hard upper bound for max_rows in the 'query' action.
    private const int MaxRowsLimit = 1000;

    // Rendered query output longer than this is cut and marked "(truncated)".
    private const int MaxOutputChars = 8000;

    // SQL string/blob literals, quoted identifiers ("x", `x`, [x]) and comments. These are blanked
    // out before keyword screening so that e.g. 'attach file' inside a literal is not flagged.
    private static readonly Regex LiteralsAndComments = new(
        @"'(?:[^']|'')*'|""(?:[^""]|"""")*""|`[^`]*`|\[[^\]]*\]|--[^\r\n]*|/\*.*?\*/",
        RegexOptions.Singleline | RegexOptions.CultureInvariant | RegexOptions.Compiled);

    // ATTACH / DETACH as whole words (identifiers such as attach_id do not match).
    private static readonly Regex AttachDetachKeyword = new(
        @"\b(?:ATTACH|DETACH)\b",
        RegexOptions.IgnoreCase | RegexOptions.CultureInvariant | RegexOptions.Compiled);

    public string Name => "sql_tool";

    public string Description =>
        "Execute SQL queries on local SQLite database files. Actions: " +
        "'query' — run SELECT query and return results as table; " +
        "'execute' — run INSERT/UPDATE/DELETE and return affected rows; " +
        "'schema' — show database schema (tables, columns, types); " +
        "'tables' — list all tables in the database.";

    public ToolParameterSchema Parameters => new()
    {
        Properties = new()
        {
            ["action"] = new()
            {
                Type = "string",
                Description = "Action to perform",
                Enum = ["query", "execute", "schema", "tables"],
            },
            ["db_path"] = new()
            {
                Type = "string",
                Description = "Path to SQLite database file (.db, .sqlite, .sqlite3)",
            },
            ["sql"] = new()
            {
                Type = "string",
                Description = "SQL query to execute (for query/execute actions)",
            },
            ["max_rows"] = new()
            {
                Type = "string",
                Description = "Maximum rows to return (default: 100, max: 1000)",
            },
        },
        Required = ["action", "db_path"],
    };

    /// <summary>
    /// Validates the arguments, opens the database and dispatches on <c>action</c>.
    /// </summary>
    /// <param name="args">JSON object with the arguments described by <see cref="Parameters"/>.</param>
    /// <param name="context">Agent context; only <c>WorkFolder</c> is read (to resolve relative paths).</param>
    /// <param name="ct">Not observed; all work here is synchronous.</param>
    /// <returns>
    /// An already-completed task. Argument problems, a missing file and any exception raised while
    /// opening or querying the database are reported through <c>ToolResult.Fail</c>.
    /// </returns>
    public Task<ToolResult> ExecuteAsync(JsonElement args, AgentContext context, CancellationToken ct = default)
    {
        // Validate up front so a missing or non-string argument becomes a tool failure
        // instead of an exception escaping to the caller.
        if (!TryGetString(args, "action", out var action))
            return Task.FromResult(ToolResult.Fail("'action' parameter is required"));
        if (!TryGetString(args, "db_path", out var dbPath) || string.IsNullOrWhiteSpace(dbPath))
            return Task.FromResult(ToolResult.Fail("'db_path' parameter is required"));

        if (!Path.IsPathRooted(dbPath))
            dbPath = Path.Combine(context.WorkFolder, dbPath);
        if (!File.Exists(dbPath))
            return Task.FromResult(ToolResult.Fail($"Database file not found: {dbPath}"));

        try
        {
            // Use the builder so ';' or '=' in the path cannot inject extra connection-string
            // keywords. Only 'execute' needs write access; the file is known to exist, so
            // ReadWrite (rather than ReadWriteCreate) is sufficient.
            var connStr = new SqliteConnectionStringBuilder
            {
                DataSource = dbPath,
                Mode = action == "execute" ? SqliteOpenMode.ReadWrite : SqliteOpenMode.ReadOnly,
            }.ToString();

            using var conn = new SqliteConnection(connStr);
            conn.Open();
            return Task.FromResult(action switch
            {
                "query" => QueryAction(conn, args),
                "execute" => ExecuteAction(conn, args),
                "schema" => SchemaAction(conn),
                "tables" => TablesAction(conn),
                _ => ToolResult.Fail($"Unknown action: {action}"),
            });
        }
        catch (Exception ex)
        {
            return Task.FromResult(ToolResult.Fail($"SQL 오류: {ex.Message}"));
        }
    }

    /// <summary>
    /// Returns true with the value when <paramref name="args"/> is a JSON object whose
    /// <paramref name="name"/> property is a JSON string; otherwise false and an empty value.
    /// </summary>
    private static bool TryGetString(JsonElement args, string name, out string value)
    {
        if (args.ValueKind == JsonValueKind.Object &&
            args.TryGetProperty(name, out var prop) &&
            prop.ValueKind == JsonValueKind.String)
        {
            value = prop.GetString() ?? "";
            return true;
        }

        value = "";
        return false;
    }

    /// <summary>
    /// Runs a read-only statement (SELECT/WITH/PRAGMA prefix) and renders up to <c>max_rows</c>
    /// rows as a pipe-separated table.
    /// </summary>
    private static ToolResult QueryAction(SqliteConnection conn, JsonElement args)
    {
        if (!args.TryGetProperty("sql", out var sqlProp))
            return ToolResult.Fail("'sql' parameter is required for query action");

        var sql = sqlProp.GetString() ?? "";
        var head = sql.TrimStart();
        // Only read-style statements are accepted (the connection is also opened read-only).
        if (!head.StartsWith("SELECT", StringComparison.OrdinalIgnoreCase) &&
            !head.StartsWith("WITH", StringComparison.OrdinalIgnoreCase) &&
            !head.StartsWith("PRAGMA", StringComparison.OrdinalIgnoreCase))
            return ToolResult.Fail("Query action only allows SELECT/WITH/PRAGMA statements. Use 'execute' for modifications.");

        var maxRows = ParseMaxRows(args);

        using var cmd = conn.CreateCommand();
        cmd.CommandText = sql;
        using var reader = cmd.ExecuteReader();

        // Header
        var colNames = new string[reader.FieldCount];
        for (var i = 0; i < colNames.Length; i++)
            colNames[i] = reader.GetName(i);

        var sb = new StringBuilder();
        sb.AppendLine(string.Join(" | ", colNames));
        sb.AppendLine(new string('-', colNames.Sum(c => c.Length + 3)));

        // Rows. The limit is checked BEFORE Read() so no row is consumed and then dropped.
        var rowCount = 0;
        var values = new string[colNames.Length];
        while (rowCount < maxRows && reader.Read())
        {
            for (var i = 0; i < values.Length; i++)
                values[i] = reader.IsDBNull(i) ? "NULL" : reader.GetValue(i)?.ToString() ?? "";
            sb.AppendLine(string.Join(" | ", values));
            rowCount++;
        }

        if (rowCount == 0)
            return ToolResult.Ok("Query returned 0 rows.");

        // Report the limit only if at least one more row actually exists.
        var limited = rowCount == maxRows && reader.Read();

        var result = sb.ToString();
        if (result.Length > MaxOutputChars)
        {
            // Do not split a UTF-16 surrogate pair at the cut point.
            var cut = char.IsHighSurrogate(result[MaxOutputChars - 1]) ? MaxOutputChars - 1 : MaxOutputChars;
            result = result[..cut] + "\n... (truncated)";
        }

        return ToolResult.Ok($"Rows: {rowCount}" + (limited ? $" (limited to {maxRows})" : "") + $"\n\n{result}");
    }

    /// <summary>
    /// Reads <c>max_rows</c>, accepting either a JSON number or a numeric string, clamped to
    /// [1, <see cref="MaxRowsLimit"/>]. Falls back to <see cref="DefaultMaxRows"/> when absent
    /// or not an integer.
    /// </summary>
    private static int ParseMaxRows(JsonElement args)
    {
        if (!args.TryGetProperty("max_rows", out var prop))
            return DefaultMaxRows;

        int? requested = prop.ValueKind switch
        {
            JsonValueKind.Number when prop.TryGetInt32(out var number) => number,
            JsonValueKind.String when int.TryParse(prop.GetString(), out var parsed) => parsed,
            _ => null,
        };

        return requested is int value ? Math.Clamp(value, 1, MaxRowsLimit) : DefaultMaxRows;
    }

    /// <summary>
    /// Runs a modifying statement batch on the read-write connection and reports affected rows.
    /// </summary>
    private static ToolResult ExecuteAction(SqliteConnection conn, JsonElement args)
    {
        if (!args.TryGetProperty("sql", out var sqlProp))
            return ToolResult.Fail("'sql' parameter is required for execute action");

        var sql = sqlProp.GetString() ?? "";
        // Block statements that could reach other database files.
        if (IsForbiddenStatement(sql))
            return ToolResult.Fail("Security: DROP DATABASE, ATTACH, DETACH are not allowed.");

        // NOTE(review): VACUUM INTO can also write a new file anywhere on disk; consider blocking it too.
        using var cmd = conn.CreateCommand();
        cmd.CommandText = sql;
        var affected = cmd.ExecuteNonQuery();
        return ToolResult.Ok($"✓ {affected} row(s) affected");
    }

    /// <summary>
    /// True when <paramref name="sql"/> contains a statement the execute action refuses.
    /// ATTACH/DETACH are searched as whole words anywhere outside literals and comments.
    /// A prefix check alone is bypassed by "SELECT 1; ATTACH ..." (every statement in the
    /// command text is executed) or by a leading comment. An unquoted identifier literally
    /// named "attach" is rejected too (conservative).
    /// </summary>
    private static bool IsForbiddenStatement(string sql)
    {
        // "DROP DATABASE" is not SQLite syntax; kept for parity with the documented guard.
        if (sql.TrimStart().StartsWith("DROP DATABASE", StringComparison.OrdinalIgnoreCase))
            return true;

        // Replace with a space (not empty) so tokens separated by a comment are not glued together.
        var stripped = LiteralsAndComments.Replace(sql, " ");
        return AttachDetachKeyword.IsMatch(stripped);
    }

    /// <summary>
    /// Lists every table in sqlite_master with its columns
    /// (cid, name, declared type, NOT NULL, default, primary key).
    /// </summary>
    private static ToolResult SchemaAction(SqliteConnection conn)
    {
        var tables = new List<string>();
        using (var listCmd = conn.CreateCommand())
        {
            listCmd.CommandText = "SELECT name FROM sqlite_master WHERE type='table' ORDER BY name";
            using var reader = listCmd.ExecuteReader();
            while (reader.Read()) tables.Add(reader.GetString(0));
        }

        var sb = new StringBuilder();
        using var pragmaCmd = conn.CreateCommand();
        // Table-valued pragma with a bound parameter: safe for table names containing quotes.
        pragmaCmd.CommandText = "SELECT cid, name, type, \"notnull\", dflt_value, pk FROM pragma_table_info($table)";
        var tableParam = pragmaCmd.Parameters.Add("$table", SqliteType.Text);

        foreach (var table in tables)
        {
            tableParam.Value = table;
            sb.AppendLine($"## {table}");
            sb.AppendLine($"{"#",-4} {"Name",-25} {"Type",-15} {"NotNull",-8} {"Default",-15} {"PK"}");

            // Disposed at the end of each iteration, before the command is executed again.
            using var pragmaReader = pragmaCmd.ExecuteReader();
            while (pragmaReader.Read())
            {
                sb.AppendLine($"{pragmaReader.GetInt32(0),-4} " +
                              $"{pragmaReader.GetString(1),-25} " +
                              $"{pragmaReader.GetString(2),-15} " +
                              $"{(pragmaReader.GetInt32(3) == 1 ? "YES" : ""),-8} " +
                              $"{(pragmaReader.IsDBNull(4) ? "" : pragmaReader.GetString(4)),-15} " +
                              $"{(pragmaReader.GetInt32(5) > 0 ? "PK" : "")}");
            }

            sb.AppendLine();
        }

        return ToolResult.Ok(sb.ToString());
    }

    /// <summary>
    /// Lists tables and views from sqlite_master with their column counts
    /// (via pragma_table_info), ordered by type and then name.
    /// </summary>
    private static ToolResult TablesAction(SqliteConnection conn)
    {
        using var cmd = conn.CreateCommand();
        cmd.CommandText = @"
SELECT m.name, m.type,
(SELECT count(*) FROM pragma_table_info(m.name)) as col_count
FROM sqlite_master m
WHERE m.type IN ('table','view')
ORDER BY m.type, m.name";
        using var reader = cmd.ExecuteReader();

        var sb = new StringBuilder();
        sb.AppendLine($"{"Name",-30} {"Type",-8} {"Columns"}");
        sb.AppendLine(new string('-', 50));

        var count = 0;
        while (reader.Read())
        {
            sb.AppendLine($"{reader.GetString(0),-30} {reader.GetString(1),-8} {reader.GetInt32(2)}");
            count++;
        }

        return ToolResult.Ok($"Found {count} tables/views:\n\n{sb}");
    }
}