Files

273 lines
10 KiB
C#

using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Runtime.InteropServices;
using System.Text;
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Data.Sqlite;
namespace AxCopilot.Services.Agent;
/// <summary>
/// Agent tool that runs SQL against an existing local SQLite database file.
/// </summary>
/// <remarks>
/// The <c>query</c>, <c>schema</c> and <c>tables</c> actions open the database read-only;
/// only <c>execute</c> opens it read-write. The file must already exist — this tool never
/// creates a database.
/// </remarks>
public class SqlTool : IAgentTool
{
    // Row limit used by 'query' when max_rows is missing, unparsable or not positive.
    private const int DefaultMaxRows = 100;

    // Hard ceiling applied to a caller-supplied max_rows.
    private const int MaxRowsLimit = 1000;

    // Rendered query output longer than this many characters is cut off.
    private const int MaxOutputChars = 8000;

    /// <summary>Tool identifier exposed to the agent.</summary>
    public string Name => "sql_tool";

    /// <summary>Human/LLM-readable description of the supported actions.</summary>
    public string Description => "Execute SQL queries on local SQLite database files. Actions: 'query' — run SELECT query and return results as table; 'execute' — run INSERT/UPDATE/DELETE and return affected rows; 'schema' — show database schema (tables, columns, types); 'tables' — list all tables in the database.";

    /// <summary>
    /// Argument schema: <c>action</c> and <c>db_path</c> are required; <c>sql</c> is used by
    /// query/execute and <c>max_rows</c> is optional. A new instance is built on every access.
    /// </summary>
    public ToolParameterSchema Parameters => new ToolParameterSchema
    {
        Properties = new Dictionary<string, ToolProperty>
        {
            ["action"] = new ToolProperty
            {
                Type = "string",
                Description = "Action to perform",
                Enum = new List<string> { "query", "execute", "schema", "tables" },
            },
            ["db_path"] = new ToolProperty
            {
                Type = "string",
                Description = "Path to SQLite database file (.db, .sqlite, .sqlite3)",
            },
            ["sql"] = new ToolProperty
            {
                Type = "string",
                Description = "SQL query to execute (for query/execute actions)",
            },
            ["max_rows"] = new ToolProperty
            {
                Type = "string",
                Description = "Maximum rows to return (default: 100, max: 1000)",
            },
        },
        Required = new List<string> { "action", "db_path" },
    };

    /// <summary>
    /// Validates the arguments, opens the database and dispatches to the requested action.
    /// </summary>
    /// <param name="args">Tool arguments object (action, db_path, sql, max_rows).</param>
    /// <param name="context">Agent context; a relative <c>db_path</c> is resolved against its WorkFolder.</param>
    /// <param name="ct">Checked once before any work starts; the SQLite calls themselves are synchronous.</param>
    /// <returns>
    /// A completed task holding the tool result. Argument problems and SQL errors are reported
    /// as failed results rather than thrown; a cancelled token yields a cancelled task.
    /// </returns>
    public Task<ToolResult> ExecuteAsync(JsonElement args, AgentContext context, CancellationToken ct = default(CancellationToken))
    {
        if (ct.IsCancellationRequested)
        {
            return Task.FromCanceled<ToolResult>(ct);
        }

        string action = GetArgText(args, "action") ?? "";
        string dbPath = GetArgText(args, "db_path") ?? "";

        if (action.Length == 0)
        {
            return Task.FromResult(ToolResult.Fail("'action' parameter is required"));
        }
        // Reject unknown actions before touching the file system or opening a connection.
        if (action is not ("query" or "execute" or "schema" or "tables"))
        {
            return Task.FromResult(ToolResult.Fail("Unknown action: " + action));
        }
        if (dbPath.Length == 0)
        {
            return Task.FromResult(ToolResult.Fail("'db_path' parameter is required"));
        }
        if (!Path.IsPathRooted(dbPath))
        {
            dbPath = Path.Combine(context.WorkFolder, dbPath);
        }
        if (!File.Exists(dbPath))
        {
            return Task.FromResult(ToolResult.Fail("Database file not found: " + dbPath));
        }

        try
        {
            // The builder quotes the path, so a path containing ';' or '=' cannot inject extra
            // connection-string keywords (for example overriding Mode=ReadOnly).
            var connectionString = new SqliteConnectionStringBuilder
            {
                DataSource = dbPath,
                // ReadWrite rather than ReadWriteCreate: never create a fresh database file
                // if it disappeared after the existence check above.
                Mode = action == "execute" ? SqliteOpenMode.ReadWrite : SqliteOpenMode.ReadOnly,
            }.ConnectionString;

            using var connection = new SqliteConnection(connectionString);
            connection.Open();

            ToolResult result = action switch
            {
                "query" => QueryAction(connection, args),
                "execute" => ExecuteAction(connection, args),
                "schema" => SchemaAction(connection),
                "tables" => TablesAction(connection),
                _ => ToolResult.Fail("Unknown action: " + action),
            };
            return Task.FromResult(result);
        }
        catch (Exception ex)
        {
            return Task.FromResult(ToolResult.Fail("SQL 오류: " + ex.Message));
        }
    }

    /// <summary>
    /// Runs a SELECT/WITH/PRAGMA statement and renders up to <c>max_rows</c> rows as a
    /// pipe-separated table, truncated to <see cref="MaxOutputChars"/> characters.
    /// </summary>
    private static ToolResult QueryAction(SqliteConnection conn, JsonElement args)
    {
        string? sql = GetArgText(args, "sql");
        if (string.IsNullOrWhiteSpace(sql))
        {
            return ToolResult.Fail("'sql' parameter is required for query action");
        }

        // Usability guard only: the connection is opened read-only, which is what actually
        // prevents writes. Comments and literals are skipped, so "-- note\nSELECT ..." passes.
        List<string> words = TokenizeSql(sql);
        string leading = words.Count > 0 ? words[0] : "";
        if (leading is not ("SELECT" or "WITH" or "PRAGMA"))
        {
            return ToolResult.Fail("Query action only allows SELECT/WITH/PRAGMA statements. Use 'execute' for modifications.");
        }

        int maxRows = ParseMaxRows(args);

        using SqliteCommand command = conn.CreateCommand();
        command.CommandText = sql;
        using SqliteDataReader reader = command.ExecuteReader();

        int fieldCount = reader.FieldCount;
        var columns = new string[fieldCount];
        for (int i = 0; i < fieldCount; i++)
        {
            columns[i] = reader.GetName(i);
        }

        var table = new StringBuilder();
        table.AppendLine(string.Join(" | ", columns));
        table.AppendLine(new string('-', columns.Sum(c => c.Length + 3)));

        var cells = new string[fieldCount];
        int rowCount = 0;
        // Test the limit before Read() so no row beyond the limit is consumed here.
        while (rowCount < maxRows && reader.Read())
        {
            for (int i = 0; i < fieldCount; i++)
            {
                cells[i] = FormatCell(reader, i);
            }
            table.AppendLine(string.Join(" | ", cells));
            rowCount++;
        }

        if (rowCount == 0)
        {
            return ToolResult.Ok("Query returned 0 rows.");
        }

        // Report "limited" only when at least one more row really exists past the limit.
        bool limited = rowCount == maxRows && reader.Read();

        string text = table.ToString();
        if (text.Length > MaxOutputChars)
        {
            int cut = MaxOutputChars;
            // Do not leave half of a surrogate pair at the cut point.
            if (char.IsHighSurrogate(text[cut - 1]))
            {
                cut--;
            }
            text = text.Substring(0, cut) + "\n... (truncated)";
        }

        return ToolResult.Ok($"Rows: {rowCount}" + (limited ? $" (limited to {maxRows})" : "") + "\n\n" + text);
    }

    /// <summary>
    /// Runs a data-modifying statement and reports the affected row count. Statements that
    /// contain ATTACH, DETACH or DROP DATABASE anywhere are refused.
    /// </summary>
    private static ToolResult ExecuteAction(SqliteConnection conn, JsonElement args)
    {
        string? sql = GetArgText(args, "sql");
        if (string.IsNullOrWhiteSpace(sql))
        {
            return ToolResult.Fail("'sql' parameter is required for execute action");
        }
        if (ContainsForbiddenStatement(sql))
        {
            return ToolResult.Fail("Security: DROP DATABASE, ATTACH, DETACH are not allowed.");
        }

        using SqliteCommand command = conn.CreateCommand();
        command.CommandText = sql;
        int affected = command.ExecuteNonQuery();
        return ToolResult.Ok($"✓ {affected} row(s) affected");
    }

    /// <summary>
    /// Lists every table from <c>sqlite_master</c> with its columns (ordinal, name, declared
    /// type, NOT NULL flag, default value, primary-key flag).
    /// </summary>
    private static ToolResult SchemaAction(SqliteConnection conn)
    {
        // Collect the names first so the listing reader is closed before per-table queries run.
        var tableNames = new List<string>();
        using (SqliteCommand listCommand = conn.CreateCommand())
        {
            listCommand.CommandText = "SELECT name FROM sqlite_master WHERE type='table' ORDER BY name";
            using SqliteDataReader names = listCommand.ExecuteReader();
            while (names.Read())
            {
                tableNames.Add(names.GetString(0));
            }
        }

        if (tableNames.Count == 0)
        {
            return ToolResult.Ok("No tables found.");
        }

        var schema = new StringBuilder();
        foreach (string table in tableNames)
        {
            schema.AppendLine($"## {table}");
            schema.AppendLine($"{"#",-4} {"Name",-25} {"Type",-15} {"NotNull",-8} {"Default",-15} {"PK"}");

            using SqliteCommand infoCommand = conn.CreateCommand();
            // Table-valued pragma with a bound parameter: a table name containing quotes
            // cannot break (or alter) the statement, unlike string-built PRAGMA table_info("...").
            infoCommand.CommandText = "SELECT cid, name, type, \"notnull\", dflt_value, pk FROM pragma_table_info($table)";
            infoCommand.Parameters.AddWithValue("$table", table);
            using SqliteDataReader columns = infoCommand.ExecuteReader();
            while (columns.Read())
            {
                string notNull = columns.GetInt32(3) == 1 ? "YES" : "";
                string defaultValue = columns.IsDBNull(4) ? "" : columns.GetString(4);
                string primaryKey = columns.GetInt32(5) > 0 ? "PK" : "";
                schema.AppendLine($"{columns.GetInt32(0),-4} {columns.GetString(1),-25} {columns.GetString(2),-15} {notNull,-8} {defaultValue,-15} {primaryKey}");
            }

            schema.AppendLine();
        }

        return ToolResult.Ok(schema.ToString());
    }

    /// <summary>
    /// Lists tables and views from <c>sqlite_master</c> with their column counts.
    /// </summary>
    private static ToolResult TablesAction(SqliteConnection conn)
    {
        using SqliteCommand command = conn.CreateCommand();
        command.CommandText = "\n SELECT m.name, m.type,\n (SELECT count(*) FROM pragma_table_info(m.name)) as col_count\n FROM sqlite_master m\n WHERE m.type IN ('table','view')\n ORDER BY m.type, m.name";
        using SqliteDataReader reader = command.ExecuteReader();

        var listing = new StringBuilder();
        listing.AppendLine($"{"Name",-30} {"Type",-8} {"Columns"}");
        listing.AppendLine(new string('-', 50));

        int count = 0;
        while (reader.Read())
        {
            listing.AppendLine($"{reader.GetString(0),-30} {reader.GetString(1),-8} {reader.GetInt32(2)}");
            count++;
        }

        return ToolResult.Ok($"Found {count} tables/views:\n\n{listing}");
    }

    /// <summary>
    /// Reads an argument as text: the value of a JSON string, the raw text of a JSON number
    /// (so a numeric max_rows is accepted even though the schema declares a string), or
    /// null when the argument is absent, <paramref name="args"/> is not an object, or the
    /// value is any other JSON kind.
    /// </summary>
    private static string? GetArgText(JsonElement args, string name)
    {
        if (args.ValueKind != JsonValueKind.Object || !args.TryGetProperty(name, out JsonElement value))
        {
            return null;
        }

        return value.ValueKind switch
        {
            JsonValueKind.String => value.GetString(),
            JsonValueKind.Number => value.GetRawText(),
            _ => null,
        };
    }

    /// <summary>
    /// Returns the requested row limit clamped to <see cref="MaxRowsLimit"/>, or
    /// <see cref="DefaultMaxRows"/> when max_rows is missing, unparsable or not positive.
    /// </summary>
    private static int ParseMaxRows(JsonElement args)
    {
        if (int.TryParse(GetArgText(args, "max_rows"), out int requested) && requested > 0)
        {
            return Math.Min(requested, MaxRowsLimit);
        }
        return DefaultMaxRows;
    }

    /// <summary>
    /// Formats one result cell: "NULL" for database nulls, a size summary for BLOBs
    /// (byte[].ToString() would only print the type name), otherwise the value's ToString().
    /// </summary>
    private static string FormatCell(SqliteDataReader reader, int ordinal)
    {
        if (reader.IsDBNull(ordinal))
        {
            return "NULL";
        }

        object value = reader.GetValue(ordinal);
        return value is byte[] blob ? $"<BLOB {blob.Length} bytes>" : (value.ToString() ?? "");
    }

    /// <summary>
    /// True when ATTACH, DETACH or DROP DATABASE appears anywhere in <paramref name="sql"/>,
    /// including in statements after a ';', outside comments, literals and quoted identifiers.
    /// Deliberately conservative: an unquoted identifier named "attach" is also refused.
    /// </summary>
    /// <remarks>
    /// NOTE(review): VACUUM INTO 'path' can also write a file at an arbitrary path and is not
    /// blocked here — confirm whether that is acceptable for this tool.
    /// </remarks>
    private static bool ContainsForbiddenStatement(string sql)
    {
        List<string> words = TokenizeSql(sql);
        for (int i = 0; i < words.Count; i++)
        {
            if (words[i] is "ATTACH" or "DETACH")
            {
                return true;
            }
            if (words[i] == "DROP" && i + 1 < words.Count && words[i + 1] == "DATABASE")
            {
                return true;
            }
        }
        return false;
    }

    /// <summary>
    /// Splits SQL into upper-cased bare words (letters, digits, '_' and '$'), skipping
    /// "--" line comments, "/* */" block comments, '...' literals and "...", `...`, [...]
    /// quoted identifiers, so keywords hidden inside them are not reported.
    /// </summary>
    private static List<string> TokenizeSql(string sql)
    {
        var words = new List<string>();
        int i = 0;
        while (i < sql.Length)
        {
            char c = sql[i];
            char next = i + 1 < sql.Length ? sql[i + 1] : '\0';

            if (c == '-' && next == '-')
            {
                int end = sql.IndexOf('\n', i);
                i = end < 0 ? sql.Length : end + 1;
            }
            else if (c == '/' && next == '*')
            {
                int end = sql.IndexOf("*/", i + 2, StringComparison.Ordinal);
                i = end < 0 ? sql.Length : end + 2;
            }
            else if (c is '\'' or '"' or '`' or '[')
            {
                i = SkipQuoted(sql, i, c == '[' ? ']' : c);
            }
            else if (IsWordChar(c))
            {
                int start = i;
                while (i < sql.Length && IsWordChar(sql[i]))
                {
                    i++;
                }
                words.Add(sql.Substring(start, i - start).ToUpperInvariant());
            }
            else
            {
                i++;
            }
        }
        return words;
    }

    /// <summary>
    /// Returns the index just past the quote that closes the quoted run opened at
    /// <paramref name="start"/> (or the end of the text if it is unterminated). A doubled
    /// quote ('' or "" or ``) inside the run is an escape and does not close it.
    /// </summary>
    private static int SkipQuoted(string sql, int start, char close)
    {
        int i = start + 1;
        while (i < sql.Length)
        {
            if (sql[i] == close)
            {
                if (close != ']' && i + 1 < sql.Length && sql[i + 1] == close)
                {
                    i += 2;
                    continue;
                }
                return i + 1;
            }
            i++;
        }
        return sql.Length;
    }

    // Characters that make up a bare SQL word for TokenizeSql.
    private static bool IsWordChar(char c) => char.IsLetterOrDigit(c) || c == '_' || c == '$';
}