diff --git a/cmd/export_sql/main.go b/cmd/export_sql/main.go new file mode 100644 index 0000000..4a40c94 --- /dev/null +++ b/cmd/export_sql/main.go @@ -0,0 +1,323 @@ +package main + +import ( + "database/sql" + "encoding/hex" + "flag" + "fmt" + "hr_receiver/config" + "io" + "os" + "path/filepath" + "strings" + "time" +) + +var defaultTables = []string{"users", "user_region_bindings", "kindergartens"} + +func main() { + var ( + tablesArg string + outputArg string + schemaArg string + ) + + flag.StringVar(&tablesArg, "tables", "", "Comma-separated table names to export. Default: users,user_region_bindings,kindergartens") + flag.StringVar(&outputArg, "output", "", "Output .sql file path. Default: export_.sql") + flag.StringVar(&schemaArg, "schema", "public", "Database schema name") + flag.Parse() + + tables := resolveTables(tablesArg, flag.Args()) + if len(tables) == 0 { + fmt.Fprintln(os.Stderr, "no tables specified") + os.Exit(1) + } + + config.InitConfig() + config.ConnectDB() + + outputPath := outputArg + if outputPath == "" { + outputPath = fmt.Sprintf("export_%s.sql", time.Now().Format("20060102_150405")) + } + + if err := os.MkdirAll(filepath.Dir(outputPath), 0o755); err != nil && filepath.Dir(outputPath) != "." { + fmt.Fprintf(os.Stderr, "create output directory failed: %v\n", err) + os.Exit(1) + } + + file, err := os.Create(outputPath) + if err != nil { + fmt.Fprintf(os.Stderr, "create output file failed: %v\n", err) + os.Exit(1) + } + defer file.Close() + + if err := exportTables(file, schemaArg, tables); err != nil { + fmt.Fprintf(os.Stderr, "export failed: %v\n", err) + os.Exit(1) + } + + fmt.Printf("exported %d table(s) to %s\n", len(tables), outputPath) +} + +func resolveTables(tablesArg string, positional []string) []string { + if len(positional) > 0 { + return dedupeNonEmpty(positional) + } + if strings.TrimSpace(tablesArg) == "" { + return append([]string(nil), defaultTables...) + } + return dedupeNonEmpty(strings.Split(tablesArg, ",")) +} + +func dedupeNonEmpty(items []string) []string { + result := make([]string, 0, len(items)) + seen := make(map[string]struct{}, len(items)) + for _, item := range items { + value := strings.TrimSpace(item) + if value == "" { + continue + } + if _, ok := seen[value]; ok { + continue + } + seen[value] = struct{}{} + result = append(result, value) + } + return result +} + +func exportTables(w io.Writer, schema string, tables []string) error { + db, err := config.DB.DB() + if err != nil { + return err + } + + header := []string{ + "-- Generated by cmd/export_sql", + fmt.Sprintf("-- Time: %s", time.Now().Format(time.RFC3339)), + fmt.Sprintf("-- Schema: %s", schema), + fmt.Sprintf("-- Tables: %s", strings.Join(tables, ", ")), + "", + } + if _, err := io.WriteString(w, strings.Join(header, "\n")); err != nil { + return err + } + + for _, table := range tables { + columns, err := getColumns(db, schema, table) + if err != nil { + return fmt.Errorf("read columns for %s: %w", table, err) + } + if len(columns) == 0 { + return fmt.Errorf("table %s not found or has no columns", table) + } + + orderBy, err := getPrimaryKeyColumns(db, schema, table) + if err != nil { + return fmt.Errorf("read primary key for %s: %w", table, err) + } + + if _, err := fmt.Fprintf(w, "-- Table: %s\n", table); err != nil { + return err + } + + if err := exportTableData(w, db, schema, table, columns, orderBy); err != nil { + return err + } + + if _, err := io.WriteString(w, "\n"); err != nil { + return err + } + } + + return nil +} + +func getColumns(db *sql.DB, schema, table string) ([]string, error) { + const query = ` +SELECT column_name +FROM information_schema.columns +WHERE table_schema = $1 AND table_name = $2 +ORDER BY ordinal_position +` + rows, err := db.Query(query, schema, table) + if err != nil { + return nil, err + } + defer rows.Close() + + var columns []string + for rows.Next() { + var column string + if err := rows.Scan(&column); err != nil { + return nil, err + } + columns = append(columns, column) + } + return columns, rows.Err() +} + +func getPrimaryKeyColumns(db *sql.DB, schema, table string) ([]string, error) { + const query = ` +SELECT a.attname +FROM pg_index i +JOIN pg_class c ON c.oid = i.indrelid +JOIN pg_namespace n ON n.oid = c.relnamespace +JOIN pg_attribute a ON a.attrelid = c.oid AND a.attnum = ANY(i.indkey) +WHERE i.indisprimary + AND n.nspname = $1 + AND c.relname = $2 +ORDER BY array_position(i.indkey, a.attnum) +` + rows, err := db.Query(query, schema, table) + if err != nil { + return nil, err + } + defer rows.Close() + + var columns []string + for rows.Next() { + var column string + if err := rows.Scan(&column); err != nil { + return nil, err + } + columns = append(columns, column) + } + return columns, rows.Err() +} + +func exportTableData(w io.Writer, db *sql.DB, schema, table string, columns, orderBy []string) error { + query := fmt.Sprintf( + `SELECT %s FROM %s.%s`, + joinIdentifiers(columns), + quoteIdentifier(schema), + quoteIdentifier(table), + ) + if len(orderBy) > 0 { + query += " ORDER BY " + joinIdentifiers(orderBy) + } + + rows, err := db.Query(query) + if err != nil { + return fmt.Errorf("query rows for %s: %w", table, err) + } + defer rows.Close() + + values := make([]any, len(columns)) + scanTargets := make([]any, len(columns)) + for i := range values { + scanTargets[i] = &values[i] + } + + rowCount := 0 + for rows.Next() { + clear(values) + if err := rows.Scan(scanTargets...); err != nil { + return fmt.Errorf("scan row for %s: %w", table, err) + } + + literals := make([]string, len(columns)) + for i, value := range values { + literals[i] = toSQLLiteral(value) + } + + stmt := fmt.Sprintf( + "INSERT INTO %s.%s (%s) VALUES (%s);\n", + quoteIdentifier(schema), + quoteIdentifier(table), + joinIdentifiers(columns), + strings.Join(literals, ", "), + ) + if _, err := io.WriteString(w, stmt); err != nil { + return err + } + rowCount++ + } + if err := rows.Err(); err != nil { + return err + } + + if rowCount == 0 { + _, err := io.WriteString(w, "-- no rows\n") + return err + } + + return nil +} + +func joinIdentifiers(columns []string) string { + quoted := make([]string, len(columns)) + for i, column := range columns { + quoted[i] = quoteIdentifier(column) + } + return strings.Join(quoted, ", ") +} + +func quoteIdentifier(value string) string { + return `"` + strings.ReplaceAll(value, `"`, `""`) + `"` +} + +func toSQLLiteral(value any) string { + switch v := value.(type) { + case nil: + return "NULL" + case string: + return quoteString(v) + case []byte: + if len(v) == 0 { + return quoteString("") + } + if isPrintableUTF8(v) { + return quoteString(string(v)) + } + return quoteString(`\x` + hex.EncodeToString(v)) + case bool: + if v { + return "TRUE" + } + return "FALSE" + case int: + return fmt.Sprintf("%d", v) + case int8: + return fmt.Sprintf("%d", v) + case int16: + return fmt.Sprintf("%d", v) + case int32: + return fmt.Sprintf("%d", v) + case int64: + return fmt.Sprintf("%d", v) + case uint: + return fmt.Sprintf("%d", v) + case uint8: + return fmt.Sprintf("%d", v) + case uint16: + return fmt.Sprintf("%d", v) + case uint32: + return fmt.Sprintf("%d", v) + case uint64: + return fmt.Sprintf("%d", v) + case float32: + return fmt.Sprintf("%g", v) + case float64: + return fmt.Sprintf("%g", v) + case time.Time: + return quoteString(v.Format(time.RFC3339Nano)) + default: + return quoteString(fmt.Sprint(v)) + } +} + +func quoteString(value string) string { + return "'" + strings.ReplaceAll(value, "'", "''") + "'" +} + +func isPrintableUTF8(data []byte) bool { + for _, b := range data { + if b == 0 { + return false + } + } + return true +} diff --git a/config.sample.yaml b/config.sample.yaml index ea849b3..c607f2f 100644 --- a/config.sample.yaml +++ b/config.sample.yaml @@ -1,3 +1,6 @@ +server: + port: 8081 + database: host: localhost #when use docker change to "db" port: 5432 diff --git a/controllers/ai.go b/controllers/ai.go index 7c4d3b9..ed02abc 100644 --- a/controllers/ai.go +++ b/controllers/ai.go @@ -256,7 +256,7 @@ func callAIForAnalysis(prompt string) (*aiAnalysisResult, error) { }, Temperature: 0.6, TopP: 0.6, - MaxCompletionTokens: 4000, + MaxCompletionTokens: 8000, }, ) if err != nil { @@ -472,7 +472,7 @@ func (tc *TrainingController) streamAIAnalysis(c *gin.Context, prompt string, }, Temperature: 0.6, TopP: 0.6, - MaxCompletionTokens: 4000, + MaxCompletionTokens: 8000, Stream: true, StreamOptions: &openai.StreamOptions{ IncludeUsage: true, diff --git a/scripts/export_db.ps1 b/scripts/export_db.ps1 new file mode 100644 index 0000000..2472b80 --- /dev/null +++ b/scripts/export_db.ps1 @@ -0,0 +1,12 @@ +$ErrorActionPreference = "Stop" + +$scriptDir = Split-Path -Parent $MyInvocation.MyCommand.Path +$projectRoot = Resolve-Path (Join-Path $scriptDir "..") + +Push-Location $projectRoot +try { + go run ./cmd/export_sql @args +} +finally { + Pop-Location +}