From ace6e6364ceb6f7ba46d2c2bb3f4955ff41f63e7 Mon Sep 17 00:00:00 2001 From: Arthur Barr Date: Tue, 20 Feb 2018 14:35:58 +0000 Subject: [PATCH] Use context and waitgroup for mirroring --- cmd/runmqserver/logging.go | 111 +++++++++++++++++++++++++++++++++ cmd/runmqserver/main.go | 99 +++++------------------------ cmd/runmqserver/mirror.go | 73 +++++++++++++--------- cmd/runmqserver/mirror_test.go | 53 ++++++++++++---- 4 files changed, 208 insertions(+), 128 deletions(-) create mode 100644 cmd/runmqserver/logging.go diff --git a/cmd/runmqserver/logging.go b/cmd/runmqserver/logging.go new file mode 100644 index 0000000..b49f3e7 --- /dev/null +++ b/cmd/runmqserver/logging.go @@ -0,0 +1,111 @@ +/* +© Copyright IBM Corporation 2017, 2018 + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +package main + +import ( + "context" + "encoding/json" + "fmt" + "os" + "path/filepath" + "sync" + + "github.com/ibm-messaging/mq-container/internal/mqini" + "github.com/sirupsen/logrus" + log "github.com/sirupsen/logrus" +) + +var debug = false + +// timestampFormat matches the format used by MQ messages (includes milliseconds) +const timestampFormat string = "2006-01-02T15:04:05.000Z07:00" + +type simpleTextFormatter struct { +} + +func (f *simpleTextFormatter) Format(entry *logrus.Entry) ([]byte, error) { + // If debugging, and a prefix, but only for this formatter. + if entry.Level == logrus.DebugLevel { + entry.Message = "DEBUG: " + entry.Message + } + // Use a simple format, with a timestamp + return []byte(formatSimple(entry.Time.Format(timestampFormat), entry.Message)), nil +} + +func logDebug(msg string) { + if debug { + log.Debugln(msg) + } +} + +func logDebugf(format string, args ...interface{}) { + if debug { + log.Debugf(format, args...) + } +} + +func jsonLogs() bool { + e := os.Getenv("MQ_ALPHA_JSON_LOGS") + if e == "true" || e == "1" { + return true + } + return false +} + +func mirrorToStdout(msg string) { + fmt.Println(msg) +} + +func formatSimple(datetime string, message string) string { + return fmt.Sprintf("%v %v\n", datetime, message) +} + +func mirrorLogs(ctx context.Context, wg *sync.WaitGroup, name string, fromStart bool) (chan error, error) { + // Always use the JSON log as the source + // Put the queue manager name in quotes to handle cases like name=.. + qm, err := mqini.GetQueueManager(name) + if err != nil { + logDebugf("%v", err) + return nil, err + } + f := filepath.Join(mqini.GetErrorLogDirectory(qm), "AMQERR01.json") + // f := fmt.Sprintf("/var/mqm/qmgrs/\"%v\"/errors/AMQERR01.json", name) + if jsonLogs() { + return mirrorLog(ctx, wg, f, fromStart, mirrorToStdout) + } + return mirrorLog(ctx, wg, f, fromStart, func(msg string) { + // Parse the JSON message, and print a simplified version + var obj map[string]interface{} + json.Unmarshal([]byte(msg), &obj) + fmt.Printf(formatSimple(obj["ibm_datetime"].(string), obj["message"].(string))) + }) +} + +func configureLogger() { + if jsonLogs() { + formatter := logrus.JSONFormatter{ + FieldMap: logrus.FieldMap{ + logrus.FieldKeyMsg: "message", + logrus.FieldKeyLevel: "ibm_level", + logrus.FieldKeyTime: "ibm_datetime", + }, + TimestampFormat: timestampFormat, + } + logrus.SetFormatter(&formatter) + } else { + log.SetFormatter(new(simpleTextFormatter)) + } +} diff --git a/cmd/runmqserver/main.go b/cmd/runmqserver/main.go index 4226628..b3a77f6 100644 --- a/cmd/runmqserver/main.go +++ b/cmd/runmqserver/main.go @@ -18,39 +18,24 @@ limitations under the License. package main import ( - "encoding/json" + "context" "errors" - "fmt" "io" "io/ioutil" "os" "os/exec" "path/filepath" "strings" + "sync" "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus" "github.com/ibm-messaging/mq-container/internal/command" - "github.com/ibm-messaging/mq-container/internal/mqini" "github.com/ibm-messaging/mq-container/internal/name" "github.com/ibm-messaging/mq-container/internal/ready" ) -var debug = false - -func logDebug(msg string) { - if debug { - log.Debug(msg) - } -} - -func logDebugf(format string, args ...interface{}) { - if debug { - log.Debugf(format, args...) - } -} - // createDirStructure creates the default MQ directory structure under /var/mqm func createDirStructure() error { out, _, err := command.Run("/opt/mqm/bin/crtmqdir", "-f", "-s") @@ -156,69 +141,6 @@ func stopQueueManager(name string) error { return nil } -func jsonLogs() bool { - e := os.Getenv("MQ_ALPHA_JSON_LOGS") - if e == "true" || e == "1" { - return true - } - return false -} - -func mirrorLogs(name string, fromStart bool) (chan bool, error) { - // Always use the JSON log as the source - // Put the queue manager name in quotes to handle cases like name=.. - qm, err := mqini.GetQueueManager(name) - if err != nil { - logDebugf("%v", err) - return nil, err - } - f := filepath.Join(mqini.GetErrorLogDirectory(qm), "AMQERR01.json") - // f := fmt.Sprintf("/var/mqm/qmgrs/\"%v\"/errors/AMQERR01.json", name) - if jsonLogs() { - return mirrorLog(f, fromStart, func(msg string) { - // Print the message straight to stdout - fmt.Println(msg) - }) - } - return mirrorLog(f, fromStart, func(msg string) { - // Parse the JSON message, and print a simplified version - var obj map[string]interface{} - json.Unmarshal([]byte(msg), &obj) - fmt.Printf("%v %v\n", obj["ibm_datetime"], obj["message"]) - }) -} - -type simpleTextFormatter struct { -} - -const timestampFormat string = "2006-01-02T15:04:05.000Z07:00" - -func (f *simpleTextFormatter) Format(entry *logrus.Entry) ([]byte, error) { - // If debugging, and a prefix, but only for this formatter. - if entry.Level == logrus.DebugLevel { - entry.Message = "DEBUG: " + entry.Message - } - // Use a simple format, with a timestamp - return []byte(fmt.Sprintf("%v %v\n", entry.Time.Format(timestampFormat), entry.Message)), nil -} - -func configureLogger() { - if jsonLogs() { - formatter := logrus.JSONFormatter{ - FieldMap: logrus.FieldMap{ - logrus.FieldKeyMsg: "message", - logrus.FieldKeyLevel: "ibm_level", - logrus.FieldKeyTime: "ibm_datetime", - }, - // Match time stamp format used by MQ messages (includes milliseconds) - TimestampFormat: timestampFormat, - } - logrus.SetFormatter(&formatter) - } else { - log.SetFormatter(new(simpleTextFormatter)) - } -} - func doMain() error { configureLogger() err := ready.Clear() @@ -262,10 +184,21 @@ func doMain() error { if err != nil { return err } - mirrorLifecycle, err := mirrorLogs(name, newQM) + var wg sync.WaitGroup + ctx, cancelMirror := context.WithCancel(context.Background()) + defer func() { + log.Debugln("Cancel log mirroring") + cancelMirror() + }() + // TODO: Use the error channel + _, err = mirrorLogs(ctx, &wg, name, newQM) if err != nil { return err } + defer func() { + log.Debugln("Waiting for log mirroring to complete") + wg.Wait() + }() err = updateCommandLevel() if err != nil { return err @@ -285,10 +218,6 @@ func doMain() error { ready.Set() // Wait for terminate signal <-signalControl - // Tell the mirroring goroutine to shutdown - mirrorLifecycle <- true - // Wait for the mirroring goroutine to finish cleanly - <-mirrorLifecycle return nil } diff --git a/cmd/runmqserver/mirror.go b/cmd/runmqserver/mirror.go index 613767f..2afa6cb 100644 --- a/cmd/runmqserver/mirror.go +++ b/cmd/runmqserver/mirror.go @@ -17,32 +17,39 @@ package main import ( "bufio" + "context" "fmt" "os" + "sync" "time" log "github.com/sirupsen/logrus" ) // waitForFile waits until the specified file exists -func waitForFile(path string) (os.FileInfo, error) { +func waitForFile(ctx context.Context, path string) (os.FileInfo, error) { var fi os.FileInfo var err error // Wait for file to exist for { - fi, err = os.Stat(path) - if err != nil { - if os.IsNotExist(err) { - time.Sleep(500 * time.Millisecond) - continue - } else { - return nil, err + select { + // Check to see if cancellation has been requested + case <-ctx.Done(): + return nil, nil + default: + fi, err = os.Stat(path) + if err != nil { + if os.IsNotExist(err) { + time.Sleep(500 * time.Millisecond) + continue + } else { + return nil, err + } } + log.Debugf("File exists: %v, %v", path, fi.Size()) + return fi, nil } - break } - log.Debugf("File exists: %v, %v", path, fi.Size()) - return fi, nil } type mirrorFunc func(msg string) @@ -75,8 +82,8 @@ func mirrorAvailableMessages(f *os.File, mf mirrorFunc) { // mirrorLog tails the specified file, and logs each line to stdout. // This is useful for usability, as the container console log can show // messages from the MQ error logs. -func mirrorLog(path string, fromStart bool, mf mirrorFunc) (chan bool, error) { - lifecycle := make(chan bool) +func mirrorLog(ctx context.Context, wg *sync.WaitGroup, path string, fromStart bool, mf mirrorFunc) (chan error, error) { + errorChannel := make(chan error, 1) var offset int64 = -1 var f *os.File var err error @@ -104,19 +111,29 @@ func mirrorLog(path string, fromStart bool, mf mirrorFunc) (chan bool, error) { offset = fi.Size() } + // Increment wait group counter, only if the goroutine gets started + wg.Add(1) go func() { + // Notify the wait group when this goroutine ends + defer func() { + log.Debugf("Finished monitoring %v", path) + wg.Done() + }() if f == nil { // File didn't exist, so need to wait for it - fi, err = waitForFile(path) + fi, err = waitForFile(ctx, path) if err != nil { log.Errorln(err) - lifecycle <- true + errorChannel <- err + return + } + if fi == nil { return } f, err = os.OpenFile(path, os.O_RDONLY, 0) if err != nil { log.Errorln(err) - lifecycle <- true + errorChannel <- err return } } @@ -124,7 +141,7 @@ func mirrorLog(path string, fromStart bool, mf mirrorFunc) (chan bool, error) { fi, err = f.Stat() if err != nil { log.Errorln(err) - lifecycle <- true + errorChannel <- err return } // The file now exists. If it didn't exist before we started, offset=0 @@ -135,18 +152,16 @@ func mirrorLog(path string, fromStart bool, mf mirrorFunc) (chan bool, error) { } closing := false for { - log.Debugln("Start of loop") // If there's already data there, mirror it now. mirrorAvailableMessages(f, mf) - log.Debugf("Stat %v", path) newFI, err := os.Stat(path) if err != nil { log.Error(err) - lifecycle <- true + errorChannel <- err return } if !os.SameFile(fi, newFI) { - log.Debugln("Not the same file!") + log.Debugln("Detected log rotation") // WARNING: There is a possible race condition here. If *another* // log rotation happens before we can open the new file, then we // could skip all those messages. This could happen with a very small @@ -165,21 +180,19 @@ func mirrorLog(path string, fromStart bool, mf mirrorFunc) (chan bool, error) { // Don't seek this time, because we know it's a new file mirrorAvailableMessages(f, mf) } - log.Debugln("Check for lifecycle event") select { - // Have we been asked to shut down? - case <-lifecycle: + case <-ctx.Done(): + log.Debug("Context cancelled") + if closing { + log.Debug("Shutting down mirror") + return + } // Set a flag, to allow one more time through the loop closing = true default: - if closing { - log.Debugln("Shutting down mirror") - lifecycle <- true - return - } time.Sleep(500 * time.Millisecond) } } }() - return lifecycle, nil + return errorChannel, nil } diff --git a/cmd/runmqserver/mirror_test.go b/cmd/runmqserver/mirror_test.go index 5dddc66..2ce8809 100644 --- a/cmd/runmqserver/mirror_test.go +++ b/cmd/runmqserver/mirror_test.go @@ -16,12 +16,15 @@ limitations under the License. package main import ( + "context" "fmt" "io/ioutil" "os" "strconv" "strings" + "sync" "testing" + "time" log "github.com/sirupsen/logrus" ) @@ -38,7 +41,9 @@ func TestMirrorLogWithoutRotation(t *testing.T) { t.Log(tmp.Name()) defer os.Remove(tmp.Name()) count := 0 - lifecycle, err := mirrorLog(tmp.Name(), true, func(msg string) { + ctx, cancel := context.WithCancel(context.Background()) + var wg sync.WaitGroup + _, err = mirrorLog(ctx, &wg, tmp.Name(), true, func(msg string) { count++ }) if err != nil { @@ -53,9 +58,8 @@ func TestMirrorLogWithoutRotation(t *testing.T) { fmt.Fprintln(f, "{\"message\"=\"B\"}") fmt.Fprintln(f, "{\"message\"=\"C\"}") f.Close() - lifecycle <- true - <-lifecycle - + cancel() + wg.Wait() if count != 3 { t.Fatalf("Expected 3 log entries; got %v", count) } @@ -78,7 +82,9 @@ func TestMirrorLogWithRotation(t *testing.T) { os.Remove(tmp.Name()) }() count := 0 - lifecycle, err := mirrorLog(tmp.Name(), true, func(msg string) { + ctx, cancel := context.WithCancel(context.Background()) + var wg sync.WaitGroup + _, err = mirrorLog(ctx, &wg, tmp.Name(), true, func(msg string) { count++ }) if err != nil { @@ -109,9 +115,8 @@ func TestMirrorLogWithRotation(t *testing.T) { f.Close() // Shut the mirroring down - lifecycle <- true - // Wait until it's finished - <-lifecycle + cancel() + wg.Wait() if count != 5 { t.Fatalf("Expected 5 log entries; got %v", count) @@ -130,7 +135,9 @@ func testMirrorLogExistingFile(t *testing.T, newQM bool) int { ioutil.WriteFile(tmp.Name(), []byte("{\"message\"=\"A\"}\n"), 0600) defer os.Remove(tmp.Name()) count := 0 - lifecycle, err := mirrorLog(tmp.Name(), newQM, func(msg string) { + ctx, cancel := context.WithCancel(context.Background()) + var wg sync.WaitGroup + _, err = mirrorLog(ctx, &wg, tmp.Name(), newQM, func(msg string) { count++ }) if err != nil { @@ -144,8 +151,8 @@ func testMirrorLogExistingFile(t *testing.T, newQM bool) int { fmt.Fprintln(f, "{\"message\"=\"B\"}") fmt.Fprintln(f, "{\"message\"=\"C\"}") f.Close() - lifecycle <- true - <-lifecycle + cancel() + wg.Wait() return count } @@ -167,6 +174,26 @@ func TestMirrorLogExistingFileButNewQueueManager(t *testing.T) { } } -func init() { - log.SetLevel(log.DebugLevel) +func TestMirrorLogCancelWhileWaiting(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + var wg sync.WaitGroup + defer func() { + cancel() + wg.Wait() + }() + _, err := mirrorLog(ctx, &wg, "fake.log", true, func(msg string) { + }) + if err != nil { + t.Error(err) + } + time.Sleep(time.Second * 3) + cancel() + wg.Wait() + // No need to assert anything. If it didn't work, the code would have hung (TODO: not ideal) +} + +func init() { + fmt.Println("Setting debug level") + log.SetLevel(log.DebugLevel) + log.SetFormatter(new(simpleTextFormatter)) }