diff --git a/main.go b/main.go index 6f28193d6..58e147394 100644 --- a/main.go +++ b/main.go @@ -2,8 +2,8 @@ package main import ( "crypto/tls" + "database/sql" "encoding/json" - "flag" "fmt" "html/template" "io" @@ -11,20 +11,20 @@ import ( "log" "net/http" "os" - "path" - "path/filepath" "regexp" "strings" "sync" "time" + + _ "github.com/lib/pq" ) var ( - keyFile = flag.String("key", "", "Key file") - certFile = flag.String("cert", "", "Certificate file") - dbDir = flag.String("db", "", "Database directory") - port = flag.Int("port", 8443, "Listen port") - tpl *template.Template + keyFile = getEnvDefault("UR_KEY_FILE", "key.pem") + certFile = getEnvDefault("UR_CRT_FILE", "crt.pem") + dbConn = getEnvDefault("UR_DB_URL", "postgres://user:password@localhost/ur?sslmode=disable") + listenAddr = getEnvDefault("UR_LISTEN", "0.0.0.0:8443") + tpl *template.Template ) var funcs = map[string]interface{}{ @@ -32,30 +32,128 @@ var funcs = map[string]interface{}{ "number": number, } +func getEnvDefault(key, def string) string { + if val := os.Getenv(key); val != "" { + return val + } + return def +} + +type report struct { + Received time.Time // Only from DB + + UniqueID string + Version string + LongVersion string + Platform string + NumFolders int + NumDevices int + TotFiles int + FolderMaxFiles int + TotMiB int + FolderMaxMiB int + MemoryUsageMiB int + SHA256Perf float64 + MemorySize int + + Date string +} + +func (r *report) Validate() error { + if r.UniqueID == "" || r.Version == "" || r.Platform == "" { + return fmt.Errorf("missing required field") + } + if len(r.Date) != 8 { + return fmt.Errorf("date not initialized") + } + return nil +} + +func setupDB(db *sql.DB) error { + _, err := db.Exec(`CREATE TABLE IF NOT EXISTS Reports ( + Received TIMESTAMP NOT NULL, + UniqueID VARCHAR(32) NOT NULL, + Version VARCHAR(32) NOT NULL, + LongVersion VARCHAR(256) NOT NULL, + Platform VARCHAR(32) NOT NULL, + NumFolders INTEGER NOT NULL, + NumDevices INTEGER NOT NULL, + TotFiles INTEGER NOT NULL, + FolderMaxFiles INTEGER NOT NULL, + TotMiB INTEGER NOT NULL, + FolderMaxMiB INTEGER NOT NULL, + MemoryUsageMiB INTEGER NOT NULL, + SHA256Perf DOUBLE PRECISION NOT NULL, + MemorySize INTEGER NOT NULL, + Date VARCHAR(8) NOT NULL + )`) + if err != nil { + return err + } + + row := db.QueryRow(`SELECT 'UniqueIDIndex'::regclass`) + if err := row.Scan(nil); err != nil { + _, err = db.Exec(`CREATE UNIQUE INDEX UniqueIDIndex ON Reports (Date, UniqueID)`) + } + + row = db.QueryRow(`SELECT 'ReceivedIndex'::regclass`) + if err := row.Scan(nil); err != nil { + _, err = db.Exec(`CREATE INDEX ReceivedIndex ON Reports (Received)`) + } + + return err +} + +func insertReport(db *sql.DB, r report) error { + _, err := db.Exec(`INSERT INTO Reports VALUES (now(), $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14)`, + r.UniqueID, r.Version, r.LongVersion, r.Platform, r.NumFolders, + r.NumDevices, r.TotFiles, r.FolderMaxFiles, r.TotMiB, r.FolderMaxMiB, + r.MemoryUsageMiB, r.SHA256Perf, r.MemorySize, r.Date) + + return err +} + +type withDBFunc func(*sql.DB, http.ResponseWriter, *http.Request) + +func withDB(db *sql.DB, f withDBFunc) http.HandlerFunc { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + f(db, w, r) + }) +} + func main() { - log.SetFlags(log.Lshortfile) + log.SetFlags(log.Ltime | log.Ldate) log.SetOutput(os.Stdout) - flag.Parse() + + // Template fd, err := os.Open("static/index.html") if err != nil { - log.Fatal(err) + log.Fatalln("template:", err) } bs, err := ioutil.ReadAll(fd) if err != nil { - log.Fatal(err) + log.Fatalln("template:", err) } fd.Close() tpl = template.Must(template.New("index.html").Funcs(funcs).Parse(string(bs))) - http.HandleFunc("/", rootHandler) - http.HandleFunc("/newdata", newDataHandler) - http.HandleFunc("/report", reportHandler) - http.Handle("/static/", http.StripPrefix("/static/", http.FileServer(http.Dir("static")))) + // DB - cert, err := tls.LoadX509KeyPair(*certFile, *keyFile) + db, err := sql.Open("postgres", dbConn) if err != nil { - log.Fatal(err) + log.Fatalln("database:", err) + } + err = setupDB(db) + if err != nil { + log.Fatalln("database:", err) + } + + // TLS + + cert, err := tls.LoadX509KeyPair(certFile, keyFile) + if err != nil { + log.Fatalln("tls:", err) } cfg := &tls.Config{ @@ -63,28 +161,32 @@ func main() { SessionTicketsDisabled: true, } - listener, err := tls.Listen("tcp", fmt.Sprintf(":%d", *port), cfg) - if err != nil { - log.Fatal(err) - } + // HTTPS - log.Println("Listening on", listener.Addr()) + listener, err := tls.Listen("tcp", listenAddr, cfg) + if err != nil { + log.Fatalln("https:", err) + } srv := http.Server{ ReadTimeout: 5 * time.Second, WriteTimeout: 5 * time.Second, } + http.HandleFunc("/", withDB(db, rootHandler)) + http.HandleFunc("/newdata", withDB(db, newDataHandler)) + http.Handle("/static/", http.StripPrefix("/static/", http.FileServer(http.Dir("static")))) + err = srv.Serve(listener) if err != nil { - log.Fatal(err) + log.Fatalln("https:", err) } } -func rootHandler(w http.ResponseWriter, r *http.Request) { +func rootHandler(db *sql.DB, w http.ResponseWriter, r *http.Request) { if r.URL.Path == "/" || r.URL.Path == "/index.html" { k := timestamp() - rep := getReport(k) + rep := getReport(db, k) w.Header().Set("Content-Type", "text/html; charset=utf-8") err := tpl.Execute(w, rep) @@ -96,126 +198,32 @@ func rootHandler(w http.ResponseWriter, r *http.Request) { } } -func reportHandler(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - k := timestamp() - rep := getReport(k) - json.NewEncoder(w).Encode(rep) -} +func newDataHandler(db *sql.DB, w http.ResponseWriter, r *http.Request) { + defer r.Body.Close() -func newDataHandler(w http.ResponseWriter, r *http.Request) { - today := time.Now().Format("20060102") - dir := filepath.Join(*dbDir, today) - ensureDir(dir, 0700) + var rep report + rep.Date = time.Now().UTC().Format("20060102") - var m map[string]interface{} lr := &io.LimitedReader{R: r.Body, N: 10240} - err := json.NewDecoder(lr).Decode(&m) - - if err != nil { - log.Println(err) - http.Error(w, err.Error(), 500) + if err := json.NewDecoder(lr).Decode(&rep); err != nil { + log.Println("json decode:", err) + http.Error(w, "JSON Decode Error", http.StatusInternalServerError) return } - id, ok := m["uniqueID"] - if ok { - idStr, ok := id.(string) - if !ok { - if err != nil { - log.Printf("No ID (type was %T)", id) - http.Error(w, "No ID", 500) - return - } - } - if idStr == "" { - log.Println("No ID (empty)") - http.Error(w, "No ID", 500) - return - } - - // The ID is base64 encoded, so can contain slashes. Replace those with dots instead. - idStr = strings.Replace(idStr, "/", ".", -1) - - f, err := os.Create(path.Join(dir, idStr+".json")) - if err != nil { - log.Println(err) - http.Error(w, err.Error(), 500) - return - } - err = json.NewEncoder(f).Encode(m) - if err != nil { - log.Println(err) - http.Error(w, err.Error(), 500) - return - } - err = f.Close() - if err != nil { - log.Println(err) - http.Error(w, err.Error(), 500) - return - } - - log.Printf("Report from %q", id) - } else { - log.Println("No ID (missing)") - http.Error(w, "No ID", 500) + if err := rep.Validate(); err != nil { + log.Println("validate:", err) + log.Printf("%#v", rep) + http.Error(w, "Validation Error", http.StatusInternalServerError) return } -} -type report struct { - UniqueID string - Version string - Platform string - NumRepos int - NumNodes int - TotFiles int - RepoMaxFiles int - TotMiB int - RepoMaxMiB int - MemoryUsageMiB int - SHA256Perf float64 - MemorySize int -} - -func fileList() ([]string, error) { - files := make(map[string]string) - t0 := time.Now().Add(-24 * time.Hour).Format("20060102") - t1 := time.Now().Format("20060102") - - dir := filepath.Join(*dbDir, t0) - gr, err := filepath.Glob(filepath.Join(dir, "*.json")) - if err != nil { - return nil, err + if err := insertReport(db, rep); err != nil { + log.Println("insert:", err) + log.Printf("%#v", rep) + http.Error(w, "Database Error", http.StatusInternalServerError) + return } - for _, f := range gr { - bn := filepath.Base(f) - files[bn] = f - } - - dir = filepath.Join(*dbDir, t1) - gr, err = filepath.Glob(filepath.Join(dir, "*.json")) - if err != nil { - return nil, err - } - for _, f := range gr { - bn := filepath.Base(f) - files[bn] = f - } - - l := make([]string, 0, len(files)) - for _, f := range files { - si, err := os.Stat(f) - if err != nil { - continue - } - if time.Since(si.ModTime()) < 24*time.Hour { - l = append(l, f) - } - } - - return l, nil } type category struct { @@ -229,7 +237,7 @@ type category struct { var reportCache map[string]interface{} var reportMutex sync.Mutex -func getReport(key string) map[string]interface{} { +func getReport(db *sql.DB, key string) map[string]interface{} { reportMutex.Lock() defer reportMutex.Unlock() @@ -237,17 +245,12 @@ func getReport(key string) map[string]interface{} { return reportCache } - files, err := fileList() - if err != nil { - return nil - } - nodes := 0 var versions []string var platforms []string var oses []string - var numRepos []int - var numNodes []int + var numFolders []int + var numDevices []int var totFiles []int var maxFiles []int var totMiB []int @@ -256,41 +259,48 @@ func getReport(key string) map[string]interface{} { var sha256Perf []float64 var memorySize []int - for _, fn := range files { - f, err := os.Open(fn) - if err != nil { - continue - } + rows, err := db.Query(`SELECT * FROM Reports WHERE Received > now() - '1 day'::INTERVAL`) + if err != nil { + log.Println("sql:", err) + return nil + } + defer rows.Close() + + for rows.Next() { var rep report - err = json.NewDecoder(f).Decode(&rep) + err := rows.Scan(&rep.Received, &rep.UniqueID, &rep.Version, + &rep.LongVersion, &rep.Platform, &rep.NumFolders, &rep.NumDevices, + &rep.TotFiles, &rep.FolderMaxFiles, &rep.TotMiB, &rep.FolderMaxMiB, + &rep.MemoryUsageMiB, &rep.SHA256Perf, &rep.MemorySize, &rep.Date) + if err != nil { - continue + log.Println("sql:", err) + return nil } - f.Close() nodes++ versions = append(versions, transformVersion(rep.Version)) platforms = append(platforms, rep.Platform) ps := strings.Split(rep.Platform, "-") oses = append(oses, ps[0]) - if rep.NumRepos > 0 { - numRepos = append(numRepos, rep.NumRepos) + if rep.NumFolders > 0 { + numFolders = append(numFolders, rep.NumFolders) } - if rep.NumNodes > 0 { - numNodes = append(numNodes, rep.NumNodes) + if rep.NumDevices > 0 { + numDevices = append(numDevices, rep.NumDevices) } if rep.TotFiles > 0 { totFiles = append(totFiles, rep.TotFiles) } - if rep.RepoMaxFiles > 0 { - maxFiles = append(maxFiles, rep.RepoMaxFiles) + if rep.FolderMaxFiles > 0 { + maxFiles = append(maxFiles, rep.FolderMaxFiles) } if rep.TotMiB > 0 { totMiB = append(totMiB, rep.TotMiB*(1<<20)) } - if rep.RepoMaxMiB > 0 { - maxMiB = append(maxMiB, rep.RepoMaxMiB*(1<<20)) + if rep.FolderMaxMiB > 0 { + maxMiB = append(maxMiB, rep.FolderMaxMiB*(1<<20)) } if rep.MemoryUsageMiB > 0 { memoryUsage = append(memoryUsage, rep.MemoryUsageMiB*(1<<20)) @@ -329,12 +339,12 @@ func getReport(key string) map[string]interface{} { }) categories = append(categories, category{ - Values: statsForInts(numNodes), + Values: statsForInts(numDevices), Descr: "Number of Devices in Cluster", }) categories = append(categories, category{ - Values: statsForInts(numRepos), + Values: statsForInts(numFolders), Descr: "Number of Folders Configured", })