113 lines
3.4 KiB
Go
113 lines
3.4 KiB
Go
package config
|
|
|
|
import (
|
|
"fmt"
|
|
"gorm.io/driver/postgres"
|
|
)
|
|
import (
|
|
"github.com/spf13/viper"
|
|
"gorm.io/gorm"
|
|
)
|
|
|
|
var DB *gorm.DB
|
|
var App AppConfig
|
|
|
|
// DBConfig carries the PostgreSQL connection settings used by ConnectDB.
// AppConfig decodes it from the "database" section of the config file.
type DBConfig struct {
	// Host is the "database.host" value; validateDBConfig requires it.
	Host string `mapstructure:"host" yaml:"host"`
	// Port is the "database.port" value. It is a string (unlike
	// MQTTConfig.Port) and ConnectDB passes it as the DSN port value.
	Port string `mapstructure:"port" yaml:"port"`
	// User is the "database.user" value; validateDBConfig requires it.
	User string `mapstructure:"user" yaml:"user"`
	// Password is the "database.password" value. It is the only field that
	// validateDBConfig does not require to be non-empty.
	Password string `mapstructure:"password" yaml:"password"`
	// Name is the "database.name" value; ConnectDB passes it as dbname.
	Name string `mapstructure:"name" yaml:"name"`
}
|
|
|
|
// AIConfig holds the values from the "ai" section of the config file.
// GetAIConfig returns them and treats all three as required.
type AIConfig struct {
	// BaseURL is the "ai.base_url" value; presumably the root URL of the AI
	// service — confirm against the client code.
	BaseURL string `mapstructure:"base_url" yaml:"base_url"`
	// APIKey is the "ai.api_key" value.
	APIKey string `mapstructure:"api_key" yaml:"api_key"`
	// Model is the "ai.model" value.
	Model string `mapstructure:"model" yaml:"model"`
}
|
|
|
|
// MQTTConfig mirrors the "mqtt" section of the config file.
//
// NOTE(review): nothing in the visible part of this file consumes these
// values, so the field notes below only restate the config key and what the
// name suggests — confirm the semantics against the MQTT client code.
type MQTTConfig struct {
	// Enabled is the "mqtt.enabled" flag.
	Enabled bool `mapstructure:"enabled" yaml:"enabled"`
	// Host is the "mqtt.host" value (presumably the broker address).
	Host string `mapstructure:"host" yaml:"host"`
	// Port is the "mqtt.port" value; an int here, whereas DBConfig.Port is a
	// string.
	Port int `mapstructure:"port" yaml:"port"`
	// Username is the "mqtt.username" value.
	Username string `mapstructure:"username" yaml:"username"`
	// Password is the "mqtt.password" value.
	Password string `mapstructure:"password" yaml:"password"`
	// ClientIDPrefix is the "mqtt.client_id_prefix" value.
	ClientIDPrefix string `mapstructure:"client_id_prefix" yaml:"client_id_prefix"`
	// Region is the "mqtt.region" value; its meaning is not visible here.
	Region string `mapstructure:"region" yaml:"region"`
	// UseTLS is the "mqtt.use_tls" flag.
	UseTLS bool `mapstructure:"use_tls" yaml:"use_tls"`
	// QoS is the "mqtt.qos" value (presumably the MQTT quality-of-service
	// level — TODO confirm the accepted range).
	QoS int `mapstructure:"qos" yaml:"qos"`
	// EnableMeasurementSubscriptions is the
	// "mqtt.enable_measurement_subscriptions" flag.
	EnableMeasurementSubscriptions bool `mapstructure:"enable_measurement_subscriptions" yaml:"enable_measurement_subscriptions"`
	// EnableTrainingEventSubscription is the
	// "mqtt.enable_training_event_subscription" flag.
	EnableTrainingEventSubscription bool `mapstructure:"enable_training_event_subscription" yaml:"enable_training_event_subscription"`
}
|
|
|
|
type AppConfig struct {
|
|
DB DBConfig `mapstructure:"database" yaml:"database"`
|
|
AI AIConfig `mapstructure:"ai" yaml:"ai"`
|
|
MQTT MQTTConfig `mapstructure:"mqtt" yaml:"mqtt"`
|
|
Swagger SwaggerConfig `mapstructure:"swagger" yaml:"swagger"`
|
|
}
|
|
|
|
// SwaggerConfig mirrors the "swagger" section of the config file.
type SwaggerConfig struct {
	// Enabled is the "swagger.enabled" flag. Its consumer is not visible in
	// this file; presumably it toggles the Swagger endpoints — confirm.
	Enabled bool `mapstructure:"enabled" yaml:"enabled"`
}
|
|
|
|
func InitConfig() {
|
|
viper.AddConfigPath("./")
|
|
viper.SetConfigName("config")
|
|
viper.SetConfigType("yaml")
|
|
if err := viper.ReadInConfig(); err != nil {
|
|
panic("Failed to read config: " + err.Error())
|
|
}
|
|
if err := viper.Unmarshal(&App); err != nil {
|
|
panic("Failed to parse config: " + err.Error())
|
|
}
|
|
}
|
|
|
|
func ConnectDB() {
|
|
if err := validateDBConfig(App.DB); err != nil {
|
|
panic("Failed to connect database: " + err.Error())
|
|
}
|
|
|
|
dsn := "host=" + App.DB.Host +
|
|
" user=" + App.DB.User +
|
|
" password=" + App.DB.Password +
|
|
" dbname=" + App.DB.Name +
|
|
" port=" + App.DB.Port +
|
|
" sslmode=disable"
|
|
|
|
var err error
|
|
DB, err = gorm.Open(postgres.Open(dsn), &gorm.Config{})
|
|
if err != nil {
|
|
panic("Failed to connect database: " + err.Error())
|
|
}
|
|
}
|
|
|
|
func validateDBConfig(cfg DBConfig) error {
|
|
if cfg.Host == "" {
|
|
return fmt.Errorf("missing config: database.host")
|
|
}
|
|
if cfg.Port == "" {
|
|
return fmt.Errorf("missing config: database.port")
|
|
}
|
|
if cfg.User == "" {
|
|
return fmt.Errorf("missing config: database.user")
|
|
}
|
|
if cfg.Name == "" {
|
|
return fmt.Errorf("missing config: database.name")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func GetAIConfig() (baseURL, apiKey, model string, err error) {
|
|
if App.AI.BaseURL == "" {
|
|
return "", "", "", fmt.Errorf("missing config: ai.base_url")
|
|
}
|
|
if App.AI.APIKey == "" {
|
|
return "", "", "", fmt.Errorf("missing config: ai.api_key")
|
|
}
|
|
if App.AI.Model == "" {
|
|
return "", "", "", fmt.Errorf("missing config: ai.model")
|
|
}
|
|
return App.AI.BaseURL, App.AI.APIKey, App.AI.Model, nil
|
|
}
|