refactor gather

This commit is contained in:
crumbtoo
2024-03-06 17:46:35 -07:00
parent fe44fbfc77
commit f6035b8a6a
4 changed files with 133 additions and 31 deletions

View File

@@ -35,6 +35,7 @@ library
, Rlp.AltSyntax , Rlp.AltSyntax
, Rlp.AltParse , Rlp.AltParse
, Rlp.HindleyMilner , Rlp.HindleyMilner
, Rlp.HindleyMilner.Types
, Rlp.Syntax.Backstage , Rlp.Syntax.Backstage
, Rlp.Syntax.Types , Rlp.Syntax.Types
-- , Rlp.Parse.Decls -- , Rlp.Parse.Decls

View File

@@ -5,6 +5,7 @@ module Rlp.AltSyntax
Program(..), Decl(..), ExprF(..), Pat(..) Program(..), Decl(..), ExprF(..), Pat(..)
, RlpExprF, RlpExpr, Binding(..), Alter(..) , RlpExprF, RlpExpr, Binding(..), Alter(..)
, DataCon(..), Type(..) , DataCon(..), Type(..)
, pattern IntT
, Core.Name, PsName , Core.Name, PsName
, pattern (Core.:->) , pattern (Core.:->)
@@ -25,6 +26,7 @@ import Data.Function (fix)
import GHC.Generics (Generic, Generic1) import GHC.Generics (Generic, Generic1)
import Data.Hashable import Data.Hashable
import Data.Hashable.Lifted import Data.Hashable.Lifted
import GHC.Exts (IsString)
import Control.Lens import Control.Lens
import Text.Show.Deriving import Text.Show.Deriving
@@ -57,10 +59,14 @@ data Type b = VarT b
| ConT b | ConT b
| AppT (Type b) (Type b) | AppT (Type b) (Type b)
| FunT | FunT
| ForallT b (Type b)
deriving (Show, Eq, Generic) deriving (Show, Eq, Generic)
instance (Hashable b) => Hashable (Type b) instance (Hashable b) => Hashable (Type b)
pattern IntT :: (IsString b, Eq b) => Type b
pattern IntT = ConT "Int#"
instance Core.HasArrowSyntax (Type b) (Type b) (Type b) where instance Core.HasArrowSyntax (Type b) (Type b) (Type b) where
_arrowSyntax = prism make unmake where _arrowSyntax = prism make unmake where
make (s,t) = FunT `AppT` s `AppT` t make (s,t) = FunT `AppT` s `AppT` t

View File

@@ -12,6 +12,7 @@ module Rlp.HindleyMilner
import Control.Lens hiding (Context', Context, (:<)) import Control.Lens hiding (Context', Context, (:<))
import Control.Monad.Errorful import Control.Monad.Errorful
import Control.Monad.State import Control.Monad.State
import Control.Monad.Writer.Strict
import Data.Text qualified as T import Data.Text qualified as T
import Data.Pretty import Data.Pretty
import Text.Printf import Text.Printf
@@ -20,6 +21,7 @@ import Data.HashMap.Strict (HashMap)
import Data.HashMap.Strict qualified as H import Data.HashMap.Strict qualified as H
import Data.HashSet (HashSet) import Data.HashSet (HashSet)
import Data.HashSet qualified as S import Data.HashSet qualified as S
import Data.Maybe (fromMaybe)
import GHC.Generics (Generic(..), Generically(..)) import GHC.Generics (Generic(..), Generically(..))
import Data.Functor import Data.Functor
@@ -29,7 +31,8 @@ import Control.Comonad.Cofree
import Compiler.RlpcError import Compiler.RlpcError
import Rlp.AltSyntax as Rlp import Rlp.AltSyntax as Rlp
import Core.Syntax qualified as Core import Core.Syntax qualified as Core
import Core.Syntax (ExprF(..)) import Core.Syntax (ExprF(..), Lit(..))
import Rlp.HindleyMilner.Types
-------------------------------------------------------------------------------- --------------------------------------------------------------------------------
-- | Type error enum. -- | Type error enum.
@@ -67,43 +70,53 @@ type HMError = Errorful TypeError
infer = undefined infer = undefined
check = undefined check = undefined
type Context' = HashMap PsName (Type PsName)
data Constraint = Equality (Type PsName) (Type PsName)
deriving (Eq, Generic, Show)
instance Hashable Constraint
type Constraints = HashSet Constraint
data PartialJudgement = PartialJudgement Constraints Context'
deriving (Generic, Show)
deriving (Semigroup, Monoid)
via Generically PartialJudgement
constraints :: Lens' PartialJudgement Constraints
constraints = lens sa sbt where
sa (PartialJudgement cs _) = cs
sbt (PartialJudgement _ g) cs' = PartialJudgement cs' g
assumptions :: Lens' PartialJudgement Context'
assumptions = lens sa sbt where
sa (PartialJudgement _ g) = g
sbt (PartialJudgement cs _) g' = PartialJudgement cs g'
fixCofree :: (Functor f, Functor g) fixCofree :: (Functor f, Functor g)
=> Iso (Fix f) (Fix g) (Cofree f ()) (Cofree g b) => Iso (Fix f) (Fix g) (Cofree f ()) (Cofree g b)
fixCofree = iso sa bt where fixCofree = iso sa bt where
sa = foldFix (() :<) sa = foldFix (() :<)
bt (_ :< as) = Fix $ bt <$> as bt (_ :< as) = Fix $ bt <$> as
data TypeState t m = TypeState type Gather t = WriterT PartialJudgement (HM t)
{ _tsUnique :: Int
, _tsMemo :: HashMap t m
}
deriving Show
makeLenses ''TypeState addConstraint :: Constraint -> Gather t ()
addConstraint = tell . ($ mempty) . (_PartialJudgement .~) . S.singleton
lookupContext :: Applicative m => PsName -> Context' -> m (Type PsName)
lookupContext n g = maybe (error "undefined variable") pure $
H.lookup n g
-- | 'gather', but memoise the result. All recursive calls should be to
-- 'gather'', not 'gather'!
gather' :: Context'
-> Fix (RlpExprF PsName)
-> Gather (Fix (RlpExprF PsName)) (Type PsName)
gather' g e = do
t <- listen $ gather g e
lift . tell $ H.singleton e t
pure (t ^. _1)
gather :: Context'
-> Fix (RlpExprF PsName)
-> Gather (Fix (RlpExprF PsName)) (Type PsName)
gather g (Finl (LitF (IntL _))) = pure IntT
gather g (Finl (VarF n)) = lookupContext n g
gather g (Finl (AppF f x)) = do
tv <- lift freshTv
tf <- gather' g f
tx <- gather' g x
addConstraint $ Equality tf (tx :-> tv)
pure tv
demoContext :: Context'
demoContext = H.fromList
[ ("id", ForallT "a" $ VarT "a" :-> VarT "a")
]
{--
type TC t = State (TypeState t (Type PsName, PartialJudgement)) type TC t = State (TypeState t (Type PsName, PartialJudgement))
(Type PsName, PartialJudgement) (Type PsName, PartialJudgement)
@@ -140,3 +153,5 @@ gather (Fix (InL (Core.AppF f x))) = do
let j'' = mempty & constraints .~ S.singleton (Equality tf $ tx :-> tv) let j'' = mempty & constraints .~ S.singleton (Equality tf $ tx :-> tv)
pure (tv, j <> j' <> j'') pure (tv, j <> j' <> j'')
--}

View File

@@ -0,0 +1,80 @@
{-# LANGUAGE TemplateHaskell #-}
module Rlp.HindleyMilner.Types
where
--------------------------------------------------------------------------------
import Data.Hashable
import Data.HashMap.Strict (HashMap)
import Data.HashMap.Strict qualified as H
import Data.HashSet (HashSet)
import Data.HashSet qualified as S
import GHC.Generics (Generic(..), Generically(..))
import Data.Kind qualified
import Data.Text qualified as T
import Control.Monad.Writer
import Control.Monad.State
import Control.Lens hiding (Context', Context)
import Rlp.AltSyntax
--------------------------------------------------------------------------------
type Context' = HashMap PsName (Type PsName)
data Constraint = Equality (Type PsName) (Type PsName)
deriving (Eq, Generic, Show)
newtype PartialJudgement = PartialJudgement Constraints
deriving (Generic, Show)
deriving (Semigroup, Monoid)
via Generically PartialJudgement
instance Hashable Constraint
type Constraints = HashSet Constraint
type Memo t = HashMap t (Type PsName, PartialJudgement)
newtype HM t a = HM { unHM :: Int -> Memo t -> (a, Int, Memo t) }
runHM :: (Hashable t) => HM t a -> (a, Memo t)
runHM hm = let (a,_,m) = unHM hm 0 mempty in (a,m)
instance Functor (HM t) where
fmap f (HM h) = HM \n m -> h n m & _1 %~ f
instance Applicative (HM t) where
pure a = HM \n m -> (a,n,m)
HM hf <*> HM ha = HM \n m ->
let (f',n',m') = hf n m
(a,n'',m'') = ha n' m'
in (f' a, n'', m'')
instance Monad (HM t) where
HM ha >>= k = HM \n m ->
let (a,n',m') = ha n m
(a',n'',m'') = unHM (k a) n' m'
in (a',n'', m'')
instance Hashable t => MonadWriter (Memo t) (HM t) where
-- IMPORTAN! (<>) is left-biased for HashMap! append `w` to the RIGHt!
writer (a,w) = HM \n m -> (a,n,m <> w)
listen ma = HM \n m ->
let (a,n',m') = unHM ma n m
in ((a,m'),n',m')
pass maww = HM \n m ->
let ((a,ww),n',m') = unHM maww n m
in (a,n',ww m')
instance MonadState Int (HM t) where
state f = HM \n m ->
let (a,n') = f n
in (a,n',m)
freshTv :: HM t (Type PsName)
freshTv = do
n <- get
modify succ
pure . VarT $ "$a" <> T.pack (show n)
makePrisms ''PartialJudgement