Go openstack

実例に学ぶGoをテスタブルに書く基本

Go openstack

 技術部プラットフォームグループ SRE の akichan です。

 ペパボでは Nyah と呼ばれる OpenStack のプライベートクラウドを運用しており、Load Balancer as a Service(LBaaS) の Octavia が利用可能です。

 先日、このLBaaSに対する不正なアクセスからシステムを防御するために、特定のIPアドレス帯からの通信をブロックするソフトウエアをGoで実装しました。その際に、社内のGoの有識者にレビューしてもらいながら、どのようにリファクタリングを行なっていったかを通して、私と同じようなGoの初学者が押さえておくと良さそうなポイントについてお伝えできればと思います。

Amphora Protector

 今回開発した Amphora Protector について簡単に解説します。

 Octavia の LoadBalancer の実態は、HAProxy がインストールされた Amphora と呼ばれるAct/Stb構成のサーバーです。OpenStack の API を通して様々な設定を変更できますが、アクセス制限については OpenStack の Security Group をアタッチするというのが基本的な手段となります。しかし、Security Group はホワイトリスト形式で、許可する通信を定義する方式のため、ブラックリスト形式のように特定の通信をブロックするというようなルール設定が行えないので、例えば大量のリクエストを送信してくる特定のクライアントからのアクセスを拒否したいというような用途には不向きです。

 これを解決するために作成したのが Amphora Protector です。 Amphora Protector はサーバーとエージェントの2プロセスで構成されます。 サーバーはSlack、WEB APIのインターフェースを介して、LB単位でアクセス制限したい通信元IPアドレス帯の追加、削除を行います。一方、エージェントは各 Amphora で動作しており、定期的にサーバーにAPIアクセスし、自身に紐づくブロック対象のアドレス帯を取得し、登録があれば、その情報をもとに自身が動作するAmphoraaにiptablesでDROPするルールを設定し、不正な通信をブロックしています。

amphora protector イメージ

 アーキテクチャの設計とサーバーの実装は @pyama86 によるもので、私が担当したのはエージェントの実装です。

コードをテストしやすくするために

テストしにくいコード

 Amphora Protector Agent の初期の実装は次のようなものでした。

 このコードはrun関数からhttp.Clientを利用したAPI通信、およびiptableを更新するupdate関数を利用しているため、run関数の責務が多くなり、テストが実装しづらい状態でした。具体的にはrun関数のテストのためにamphora protector serverのスタブサーバーを立てる必要がある上、iptablesコマンドが実行できる環境を用意する必要があります。 このような外部環境に依存するテストは、テストケースの増加に比例し、テストのスループットが低下し、開発の生産性を落とすことになると考えました。

func main() {
  cfg := config{}

  for {
    select {
    case <-ticker.C:
      // 一定間隔で実行され、ブロック定義を追加する
      run(&cfg)
    case <-ctx.Done():
      log.Info(ctx.Err())
      return
    }
  }
}

// Amphora protectorサーバーからフィルタするIPを取得し、
// iptable同期を実行
func run(cfg *config) error {
  // httpリクエスト
  client := &http.Client{Timeout: time.Duration(10) * time.Second}
  req, _ := http.NewRequest("GET", cfg.URL, nil)
  body, _ := ioutil.ReadAll(res.Body)

  var filterIPs []string
  if err := json.Unmarshal(body, &filterIPs); err != nil {
     return err
  }

  // ☆runの中にupdate関数があるため、テストが書きづらい
  update(filterIPs)
  return nil
}

func update(ips []string) {
  // サーバから得た情報をもとにiptablesを設定
}

テストしやすいコードへのリファクタリング

 前述の問題を改善するため、外部依存を注入できるようにインターフェース、構造体を用いて再設計しました。

 まず、アクセスブロックに利用する情報を取得する、Fetcherインターフェースを定義し、これを実装するAMPSrvClient構造体を定義しました。こうすることで後述の Runner 構造体のメンバーとしてAMPSrvClientだけでなく、Fetcher インターフェースを持つテスト用のオブジェクトを埋め込むことができるようになり、テストが書きやすくなるためです。

// fetcher.go

type Fetcher interface {
  Fetch(hostname string) ([]string, error)
}

type AMPSrvClient struct {
  config *Config
}

func (a *AMPSrvClient) Fetch(hostname string) ([]string, error) {
  // Amphora Protector ServerからフィルタするIPアドレスのリストを取得

  client := &http.Client{Timeout: time.Duration(a.config.Timeout) * time.Second}
  u, _ := url.Parse(a.config.Endpoint)
  req, _ := http.NewRequest("GET", u.String(), nil)
  res, _ := client.Do(req)

  return filterIPs, nil
}

次にRunner 構造体を定義します。Runner構造体はメンバーにFetcherインターフェースを持つfetcherIPTablesインターフェースを満たすiptメンバーを持ちます。メソッドとしてConverge メソッドを定義しており、Converge メソッドはfetcher から取得した情報を元に、ipt を利用して、iptablesの状態を収束するメソッドです。

// runner.go

type Runner struct {
  fetcher Fetcher
  ipt IPTables
}

func (r *Runner) Converge(hostname string) error {
  // フィルタするIPを取得
  filterIPs, err := r.fetcher.Fetch(hostname)
  if err != nil {
    return err
  }

  // iptables更新処理
  if err := r.ipt.createChainIfNotExist(); err != nil {
  	return err
  }

  if err := r.ipt.addRule(filterIPs); err != nil {
  	return err
  }

  if err := r.ipt.removeRule(filterIPs); err != nil {
  	return err
  }
  return nil
}

 このようにinterfaceを利用することで例えばテスト時には、下記の構造体をfetcher に埋め込むことも可能です。

type TestFetcher struct {
...
}

func (t *TestFetcher) Fetch(hostname string) ([]string, error) {
  // 引数をそのまま返す
  return []string{hostname}, nil
}

 ipt についても同じような実装を行い、最終的にmain関数ではRunner構造体に必要な構造体を渡して初期化し、Convergeメソッドを呼び出すだけになりました。

// main.go
func main() {
  c, _ := NewConfig(*path)
  // Fetcherインターフェースを持つ構造体
  f := NewFetcher(c)
  // IPTablesインターフェースを持つ構造体
  i, _ := iptables.New()
  // それぞれを引数にRunner構造体を初期化
  runner := NewRunner(f, i)

  for {
    select {
      case <-ticker.C:
        err = runner.Converge(ownHostname)
        // 省略
      case <-ctx.Done():
        log.Info(ctx.Err())
        return
    }
  }
}

テーブルドリブンテストのテンプレートを簡単に生成する

Goではあるテスト対象に対して複数の条件のテストを実行したい場合、テーブルドリブンテストという手法がよく用いられます。

次の go.dev のサンプル では、ReveerseRunes関数に対してHello, worldHello, 世界の二つの場合のテストを行なっています。 冗長なコードが抑えられ、入力と期待する結果がわかりやすいという利点があります。

// https://go.dev/doc/code#Testing
package morestrings

import "testing"

func TestReverseRunes(t *testing.T) {
  cases := []struct {
    in, want string
  }{
    {"Hello, world", "dlrow ,olleH"},
    {"Hello, 世界", "界世 ,olleH"},
    {"", ""},
  }
  for _, c := range cases {
    got := ReverseRunes(c.in)
    if got != c.want {
      t.Errorf("ReverseRunes(%q) == %q, want %q", c.in, got, c.want)
    }
  }
}

単純なテストと比べるとやや複雑な書き方にはなりますが、テーブルドリブンテストを簡単に生成する cweill/gotests を使うと簡単に書けます。

例としてRunnerのメソッドのテーブルドリブンテストを生成してみます。 次のようなコマンドを実行します。

# Runnerのメソッドのテストを自動生成
$ gotests -w -only Runner

すると、シグネチャを考慮した次のようなテストのテンプレートを作成してくれます。

func TestRunner_Converge(t *testing.T) {
       type args struct {
               hostname string
       }
       tests := []struct {
               name    string
               r       *Runner
               args    args
               wantErr bool
       }{
               // TODO: Add test cases.
       }
       for _, tt := range tests {
               t.Run(tt.name, func(t *testing.T) {
                       if err := tt.r.Converge(tt.args.hostname); (err != nil) != tt.wantErr {
                               t.Errorf("Runner.Converge() error = %v, wantErr %v", err, tt.wantErr)
                       }
               })
       }
}

モックを簡単に生成する

gomockを用いることで、インターフェース定義から簡単にモックの生成を行うことができます。

例としてIPtablesインターフェース定義からモックを作成するには次のようにします。

$ mockgen -source=./iptables.go -destination=./iptables_mock.go IPTables

指定したパスにMockが出力されます。 MockIPTablesはIPTablesインターフェースを満たす構造体で、NewMockIPTablesで生成してテスト対象に注入します。 MockIPTablesMockRecorderMockIPTablesのMockの呼び出しを記録するもので、テストの際に直接使うことはありません。

// MockIPTables is a mock of IPTables interface.
type MockIPTables struct {
  ctrl     *gomock.Controller
  recorder *MockIPTablesMockRecorder
}

// MockIPTablesMockRecorder is the mock recorder for MockIPTables.
type MockIPTablesMockRecorder struct {
  mock *MockIPTables
}

// NewMockIPTables creates a new mock instance.
func NewMockIPTables(ctrl *gomock.Controller) *MockIPTables {
  mock := &MockIPTables{ctrl: ctrl}
  mock.recorder = &MockIPTablesMockRecorder{mock}
  return mock
}

// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockIPTables) EXPECT() *MockIPTablesMockRecorder {
  // 省略
}

// Append mocks base method.
func (m *MockIPTables) Append(table, chain string, rulespec ...string) error {
  // 省略
}

// Append indicates an expected call of Append.
func (mr *MockIPTablesMockRecorder) Append(table, chain interface{}, rulespec ...interface{}) *gomock.Call {
  // 省略
}

Goにおけるモックは前述のテーブルドリブンテストと合わせて次のように使えます。

func TestRunner_addRule(t *testing.T) {
  tests := []struct {
    name      string
    filterIps []string
    beforeDo  func() (IPTables, *gomock.Controller)
    wantErr   bool
  }{
    {
      name:      "ok",
      filterIps: []string{"192.0.2.0/24"},
      beforeDo: func() (IPTables, *gomock.Controller) {
        // 1. モックコントローラーの生成
        controller := gomock.NewController(t)
        // 2. IPTablesモックの生成
        ipt := NewMockIPTables(controller)

        // 3. モックに対して呼ばれる関数と引数、返り値を指定
        // `Exists(table, chain, "-s", "192.0.2.0/24", "-j", "DROP")`が1回呼ばれ、
        // `false, nil`を返すことを期待する
        ipt.EXPECT().Exists(table, chain, "-s", "192.0.2.0/24", "-j", "DROP").Return(
          false,
          nil,
        ).Times(1)
        // `Append(table, chain, "-s", "192.0.2.0/24", "-j", "DROP")`が1回呼ばれ、
        // `nil`を返すことを期待する
        ipt.EXPECT().Append(table, chain, "-s", "192.0.2.0/24", "-j", "DROP").
          Return(nil).Times(1)

        return ipt, controller
      },
      wantErr: false,
    },
  }
  for _, tt := range tests {
    t.Run(tt.name, func(t *testing.T) {
      // IPTablesのモックを Runner に注入
      ipt, controller := tt.beforeDo()
      defer controller.Finish()
      r := Runner{ipt: ipt}

      if err := r.addRule(tt.filterIps); (err != nil) != tt.wantErr {
        t.Errorf("addRule() error = %v, wantErr %v", err, tt.wantErr)
      }
    })
  }
}

まとめ

以上のリファクタリング、テスト追加の結果、開発中どんどん生産性が向上していくことを感じることができました。 またテストを意識した結果、初期実装に比べメソッドの責務が適度に分割されすっきりとしたコードになりました。

今回の経験を通してテスタブルな設計や効率的なテストコードの作成を学べたので、今後の業務のコードを書くシーンで活用していきたいと思います。