読者です 読者をやめる 読者になる 読者になる

新人SEの学習記録

14年度入社SEの学習記録用に始めたブログです。気づけば社会人3年目に突入。

学習記録:Scala関数型デザイン 第6章

第6章:純粋関数型の状態

副作用を使った乱数の生成

本章では,乱数の生成を例に用いて,状態を操作する純粋関数型のプログラムを記述する方法について見ていく。
ここでの目標は,ステートフルAPIを純粋関数型にするための基本的なパターンを示すことである。

Scalaで(擬似)乱数を生成する場合,scala.util.Randomクラスを使用する。

// 現在のシステム時刻をシードとする新しい乱数ジェネレータを生成
scala> val rng = new scala.util.Random
rng: scala.util.Random = scala.util.Random@66048bfd

scala> rng.nextDouble
res0: Double = 0.08116684838606592

scala> rng.nextDouble
res1: Double = 0.2765568131467897

scala> rng.nextInt
res2: Int = -1514756962

// 0〜9のランダムな整数を取得
scala> rng.nextInt(10)
res3: Int = 5

scala> rng.nextInt(10)
res4: Int = 3

scala> rng.nextInt(10)
res5: Int = 7

呼び出される度にオブジェクトrngの内部状態が更新されていることが想像できる。
そうでなければ,nextIntやnextDoubleが呼び出される度に同じ値が返されるはずである。
状態の更新は副作用として実行されるため,これらのメソッドは参照透過ではなく,モジュール性に乏しい。

テスタビリティを例に見ると,ランダム性を利用するメソッドを記述する場合はテストを再現可能にする必要がある。
以下のrollDieメソッドは,6面サイコロを一つ振るシミュレーションであり,1〜6の値を返す。

def rollDie: Int = {
  val rng = new scala.util.Random
  rng.nextInt(6)  // 0〜5の値を返す!
}

このメソッドには,本来返すべき値1~6ではなく,0~5を返すというエラーがある。
しかし,このメソッドを6回テストしたうちの5回は仕様を満たすことになる。
また,テストが実際に失敗した場合には,その失敗を確実に再現できることが理想的である。

乱数ジェネレータを渡すことでこの問題を解決できないだろうか。
これなら,失敗したテストを再現したい場合には,テストが失敗したのと同じジェネレータを渡せば良い。

def rollDie(rng: scala.util.Random): Int = rng.nextInt(6)

しかし,これでもまだ問題は残っている。
同じジェネレータにするには,同じシードで作成するだけでなく,同じ状態にする必要がある。
つまり,ジェネレータ作成後に,メソッドが同じ回数だけ呼び出されていなければならない。
これを保証するのは至難の業である。

では,どうすれば良いか?

純粋関数型の乱数の生成

ここで,副作用を仕様しないという原理に立ち返ってみる。
参照透過性を取り戻すためには,状態の更新を明示的なものにすることが鍵となる。
状態を副作用として更新するのではなく,生成された値とともに新しい状態を返すようにすればよい。

trait RNG {
  def nextint: (Int, RNG)
}

このメソッドは,ランダムな整数を生成するだけでなく,乱数と同時に新しい状態も返し,古い状態はそのままにしておく。
実質的には,次の状態の計算を,プログラムの他の部分に対する新しい状態の通知から切り離すことになる。
これにより,nextIntの呼び出し元が新しい状態の処理を完全に制御できるようになる。

以下に,scala.util.Randomと同じアルゴリズムを使用する純粋関数型のシンプルな乱数ジェネレータを示す。

  case class SimpleRNG(seed: Long) extends RNG {
    def nextInt: (Int, RNG) = {
      // 現在のシードを使って新しいシードを作成
      val newSeed = (seed * 0x5DEECE66DL + 0xBL) & 0xFFFFFFFFFFFFL
      // 新しいシードから作成されたRNGインスタンス = 次の状態
      val nextRNG = Simple(newSeed)
      // 新しい擬似乱数。>>>は0埋め右バイナリシフト
      val n = (newSeed >>> 16).toInt
      // 戻り値は擬似乱数とRNGの次の状態からなるタプル
      (n, nextRNG)
    }
  }

このAPIインタプリタで試してみる。

// 適当なシードとして13を入れる
scala> val rng = RNG.SimpleRNG(13)
rng: RNG.SimpleRNG = SimpleRNG(13)

// rng.nextIntから返されるペアを分解して2つの値を宣言する構文。
// 乱数n1とRNGの次の状態rng2を取得。
scala> val (n1, rng2) = rng.nextInt
n1: Int = 5001735
rng2: RNG = SimpleRNG(327793750932)

// 次の状態であるrng2を使って次の乱数を取得。
scala> val (n2, rng3) = rng2.nextInt
n2: Int = -2132165362
rng3: RNG = SimpleRNG(141741387574799)

// 同じ状態を使って呼び出せば,返り値は同じになる。
scala> val (n3, rng4) = rng2.nextInt
n3: Int = -2132165362
rng4: RNG = SimpleRNG(141741387574799)

上記のステートメントは,何回実行しても常に同じ値を返す。
例えば,rng2.nextIntを呼び出せば常に-2132165362と新しいRNGが返される。
つまり,このAPIは純粋関数である。

ステートフルAPIの純粋化

このステートフルなAPIを純粋関数化するという解決策は,乱数の生成に限ったことではない。
例えば,以下のようなクラスがあったとする。

class Foo {
  private var s: FooState = ...
  def bar: Bar
  def baz: Int
}

barとbazはそれぞれ何らかの方法でsを変化させる。
状態から状態への遷移を明確にすることで,これを純粋関数型のAPIに機械的に変換できる。

class Foo {
  def bar: (Bar, Foo)
  def baz: (Int, Foo)
}

このパターンを利用するとしたら,計算された次の状態をプログラムの他の部分に渡す責任は呼び出し元にある。
先ほどの純粋関数型のRNGの例では,前のRNGを再利用すると同じ値が生成されてしまう。
例えば,乱数のペアを返す関数を作るときは,

def randomPair(rng: RNG): (Int, Int) = {
  val (i1,_) = rng.nextInt
  val (i2,_) = rng.nextInt
  (i1, i2) // 同じ値が返ってきてしまう!
}

上のように書いてしまうと,常にi1とi2が同じ値になってしまう。
2つの異なる数を生成したい場合は,一つ目のnextIntが返すRNGを使って2つ目のnextIntを呼び出す必要がある。

def randomPair(rng: RNG): (Int, Int) = {
  val (i1,rng2) = rng.nextInt
  val (i2,rng3) = rng2.nextInt // rng2を使って生成する
  ((i1, i2), rng3)  // 乱数を2つ生成した後の最終状態を一緒に返すことで,
                        //  呼び出し元が新しい状態を使って乱数をさらに生成できる
}
Exercise 6.1
  • RNG.nextIntを使って0~Int.maxValueのランダムな整数を生成する関数を喜寿tせよ。

RNG.nextIntの絶対値を取れば良い。今回は負の値ならばマイナスを付けることで対応する。
nextIntの返り値がInt.minValueだった場合,対応する自然数が無いことに注意する必要がある。

  def nonNegativeInt(rng: RNG): (Int, RNG) = {
    val (n, rng2) = rng.nextInt
    if (n == Int.MinValue) (0, rng2)
    else if (n < 0) (-n, rng2)
    else (n, rng2)
  }

自分はMinValueのときは0にしてしまったが,
解答では(n + 1)のマイナスを取っていて,こっちの方がすっきり書ける。
また,タプルの()内でif/else文を書くこともできる模様。

  def nonNegativeInt(rng: RNG): (Int, RNG) = {
    val (n, rng2) = rng.nextInt
    // (n, rng2)を返すとき,nだけif/else文で場合分けしている
    (if (n < 0) -(n + 1) else n, rng2)
  }

使うとこんな感じ。

scala> val rng = RNG.SimpleRNG(88)
rng: RNG.SimpleRNG = SimpleRNG(88)

scala> val (n1, rng2) = RNG.nonNegativeInt(rng)
n1: Int = 33857903
rng2: RNG = SimpleRNG(2218911544707)

scala> val (n2, rng3) = RNG.nonNegativeInt(rng2)
n2: Int = 1400804657
rng3: RNG = SimpleRNG(91803134032594)
Exercise 6.2
  • 0~1(1を含まない)のDouble型の値を生成する関数を記述せよ。

nonNegativeIntで生成した乱数をInt.MaxValueで割ればOK。

  def double(rng: RNG): (Double, RNG) =	{
    val (n, rng2) = nonNegativeInt(rng)
    (n.toDouble	/ Int.MaxValue, rng2)
  }

上では(1を含まない)という仕様を入れるのを忘れていた。
解答では,(n / (Int.MaxValue.toDouble + 1), rng2)を返すことで対応していた。

Exercise 6.3
  • (Int, Double), (Double, Int), (Double, Double, Double)を生成する関数をそれぞれ記述せよ。
  // double関数が使える
  def intDouble(rng: RNG): ((Int,Double), RNG) = {
    val (n, rng2) = rng.nextInt
    val (d, rng3) = double(rng2)
    ((n, d), rng2)
  }

  // intDoubleを使って返す順序を入れ替えればOK
  def doubleInt(rng: RNG): ((Double,Int), RNG) = {
    val ((n, d), rng2) = intDouble(rng)
    ((d, n), rng2)
  }

  // doubleを3回呼び出す
  def double3(rng: RNG): ((Double,Double,Double), RNG) = {
    val (d1, rng2) = double(rng)
    val (d2, rng3) = double(rng2)
    val (d3, rng4) = double(rng3)
    ((d1,d2,d3), rng4)
  }

使うとこんな感じ。

scala> val rng = RNG.SimpleRNG(111)
rng: RNG.SimpleRNG = SimpleRNG(111)

scala> val ((n1, d1), rng2) = RNG.intDouble(rng)
n1: Int = 42707127
d1: Double = 0.2768184925787237
rng2: RNG = SimpleRNG(2798854334798)

scala> val ((d2, n2), rng3) = RNG.doubleInt(rng2)
d2: Double = 0.44440602950956953
n2: Int = 594463186
rng3: RNG = SimpleRNG(38958739384897)

scala> val ((d3, d4, d5), rng4) = RNG.double3(rng3)
d3: Double = 0.44440602950956953
d4: Double = 0.5102882396943347
d5: Double = 0.12317692820130705
rng4: RNG = SimpleRNG(17335611505970)
Exercise 6.4
  • ランダムな整数のリストを生成する関数を記述せよ。

関数intsには,リストの長さcountとrngを渡す。
再帰を使い,countが0になるまで乱数を生成してListに追加していく。

  def ints(count: Int)(rng: RNG): (List[Int], RNG) = {
    @annotation.tailrec
    def go(cnt: Int, acc: List[Int], rng: RNG): (List[Int], RNG) = {
      if (cnt <= 0) (acc, rng)
      else {
        val (n, rng2) = rng.nextInt
        go(cnt - 1, n :: acc, rng2)
      }
    }
    go(count, List(), rng)
  }

使うとこんな感じ。

scala> val rng = RNG.SimpleRNG(222)
rng: RNG.SimpleRNG = SimpleRNG(222)

scala> val (l, rng2) = RNG.ints(3)(rng)
l: List[Int] = List(1729905572, 1184694134, 85414255)
rng2: RNG = SimpleRNG(113371091627571)

scala> val (l2, rng3) = RNG.ints(10)(rng2)
l2: List[Int] = List(-2142059753, 172509991, -645757138, 734111366, 1460254146, 833898298, 1830210529, -1476172985, -1036913855, 1432996927)
rng3: RNG = SimpleRNG(141092948798861)

scala> val (l3, rng4) = RNG.ints(-1)(rng3)
l3: List[Int] = List()
rng4: RNG = SimpleRNG(141092948798861)