From 28ed317147b989864b89868eac2886dfcdfe25a3 Mon Sep 17 00:00:00 2001 From: crumbtoo Date: Wed, 6 Mar 2024 17:46:35 -0700 Subject: [PATCH] refactor gather --- rlp.cabal | 1 + src/Rlp/AltSyntax.hs | 6 +++ src/Rlp/HindleyMilner.hs | 77 +++++++++++++++++++------------- src/Rlp/HindleyMilner/Types.hs | 80 ++++++++++++++++++++++++++++++++++ 4 files changed, 133 insertions(+), 31 deletions(-) create mode 100644 src/Rlp/HindleyMilner/Types.hs diff --git a/rlp.cabal b/rlp.cabal index dcac32c..0a2f2ee 100644 --- a/rlp.cabal +++ b/rlp.cabal @@ -35,6 +35,7 @@ library , Rlp.AltSyntax , Rlp.AltParse , Rlp.HindleyMilner + , Rlp.HindleyMilner.Types , Rlp.Syntax.Backstage , Rlp.Syntax.Types -- , Rlp.Parse.Decls diff --git a/src/Rlp/AltSyntax.hs b/src/Rlp/AltSyntax.hs index cb2a39b..7254186 100644 --- a/src/Rlp/AltSyntax.hs +++ b/src/Rlp/AltSyntax.hs @@ -5,6 +5,7 @@ module Rlp.AltSyntax Program(..), Decl(..), ExprF(..), Pat(..) , RlpExprF, RlpExpr, Binding(..), Alter(..) , DataCon(..), Type(..) + , pattern IntT , Core.Name, PsName , pattern (Core.:->) @@ -25,6 +26,7 @@ import Data.Function (fix) import GHC.Generics (Generic, Generic1) import Data.Hashable import Data.Hashable.Lifted +import GHC.Exts (IsString) import Control.Lens import Text.Show.Deriving @@ -57,10 +59,14 @@ data Type b = VarT b | ConT b | AppT (Type b) (Type b) | FunT + | ForallT b (Type b) deriving (Show, Eq, Generic) instance (Hashable b) => Hashable (Type b) +pattern IntT :: (IsString b, Eq b) => Type b +pattern IntT = ConT "Int#" + instance Core.HasArrowSyntax (Type b) (Type b) (Type b) where _arrowSyntax = prism make unmake where make (s,t) = FunT `AppT` s `AppT` t diff --git a/src/Rlp/HindleyMilner.hs b/src/Rlp/HindleyMilner.hs index 05f9677..dceb6d5 100644 --- a/src/Rlp/HindleyMilner.hs +++ b/src/Rlp/HindleyMilner.hs @@ -12,6 +12,7 @@ module Rlp.HindleyMilner import Control.Lens hiding (Context', Context, (:<)) import Control.Monad.Errorful import Control.Monad.State +import Control.Monad.Writer.Strict import Data.Text qualified as T import Data.Pretty import Text.Printf @@ -20,6 +21,7 @@ import Data.HashMap.Strict (HashMap) import Data.HashMap.Strict qualified as H import Data.HashSet (HashSet) import Data.HashSet qualified as S +import Data.Maybe (fromMaybe) import GHC.Generics (Generic(..), Generically(..)) import Data.Functor @@ -29,7 +31,8 @@ import Control.Comonad.Cofree import Compiler.RlpcError import Rlp.AltSyntax as Rlp import Core.Syntax qualified as Core -import Core.Syntax (ExprF(..)) +import Core.Syntax (ExprF(..), Lit(..)) +import Rlp.HindleyMilner.Types -------------------------------------------------------------------------------- -- | Type error enum. @@ -67,43 +70,53 @@ type HMError = Errorful TypeError infer = undefined check = undefined -type Context' = HashMap PsName (Type PsName) - -data Constraint = Equality (Type PsName) (Type PsName) - deriving (Eq, Generic, Show) - -instance Hashable Constraint - -type Constraints = HashSet Constraint - -data PartialJudgement = PartialJudgement Constraints Context' - deriving (Generic, Show) - deriving (Semigroup, Monoid) - via Generically PartialJudgement - -constraints :: Lens' PartialJudgement Constraints -constraints = lens sa sbt where - sa (PartialJudgement cs _) = cs - sbt (PartialJudgement _ g) cs' = PartialJudgement cs' g - -assumptions :: Lens' PartialJudgement Context' -assumptions = lens sa sbt where - sa (PartialJudgement _ g) = g - sbt (PartialJudgement cs _) g' = PartialJudgement cs g' - fixCofree :: (Functor f, Functor g) => Iso (Fix f) (Fix g) (Cofree f ()) (Cofree g b) fixCofree = iso sa bt where sa = foldFix (() :<) bt (_ :< as) = Fix $ bt <$> as -data TypeState t m = TypeState - { _tsUnique :: Int - , _tsMemo :: HashMap t m - } - deriving Show +type Gather t = WriterT PartialJudgement (HM t) -makeLenses ''TypeState +addConstraint :: Constraint -> Gather t () +addConstraint = tell . ($ mempty) . (_PartialJudgement .~) . S.singleton + +lookupContext :: Applicative m => PsName -> Context' -> m (Type PsName) +lookupContext n g = maybe (error "undefined variable") pure $ + H.lookup n g + +-- | 'gather', but memoise the result. All recursive calls should be to +-- 'gather'', not 'gather'! + +gather' :: Context' + -> Fix (RlpExprF PsName) + -> Gather (Fix (RlpExprF PsName)) (Type PsName) +gather' g e = do + t <- listen $ gather g e + lift . tell $ H.singleton e t + pure (t ^. _1) + +gather :: Context' + -> Fix (RlpExprF PsName) + -> Gather (Fix (RlpExprF PsName)) (Type PsName) + +gather g (Finl (LitF (IntL _))) = pure IntT + +gather g (Finl (VarF n)) = lookupContext n g + +gather g (Finl (AppF f x)) = do + tv <- lift freshTv + tf <- gather' g f + tx <- gather' g x + addConstraint $ Equality tf (tx :-> tv) + pure tv + +demoContext :: Context' +demoContext = H.fromList + [ ("id", ForallT "a" $ VarT "a" :-> VarT "a") + ] + +{-- type TC t = State (TypeState t (Type PsName, PartialJudgement)) (Type PsName, PartialJudgement) @@ -140,3 +153,5 @@ gather (Fix (InL (Core.AppF f x))) = do let j'' = mempty & constraints .~ S.singleton (Equality tf $ tx :-> tv) pure (tv, j <> j' <> j'') +--} + diff --git a/src/Rlp/HindleyMilner/Types.hs b/src/Rlp/HindleyMilner/Types.hs new file mode 100644 index 0000000..daec2ba --- /dev/null +++ b/src/Rlp/HindleyMilner/Types.hs @@ -0,0 +1,80 @@ +{-# LANGUAGE TemplateHaskell #-} +module Rlp.HindleyMilner.Types + where +-------------------------------------------------------------------------------- +import Data.Hashable +import Data.HashMap.Strict (HashMap) +import Data.HashMap.Strict qualified as H +import Data.HashSet (HashSet) +import Data.HashSet qualified as S +import GHC.Generics (Generic(..), Generically(..)) +import Data.Kind qualified +import Data.Text qualified as T +import Control.Monad.Writer +import Control.Monad.State + +import Control.Lens hiding (Context', Context) + +import Rlp.AltSyntax +-------------------------------------------------------------------------------- + +type Context' = HashMap PsName (Type PsName) + +data Constraint = Equality (Type PsName) (Type PsName) + deriving (Eq, Generic, Show) + +newtype PartialJudgement = PartialJudgement Constraints + deriving (Generic, Show) + deriving (Semigroup, Monoid) + via Generically PartialJudgement + +instance Hashable Constraint + +type Constraints = HashSet Constraint + +type Memo t = HashMap t (Type PsName, PartialJudgement) + +newtype HM t a = HM { unHM :: Int -> Memo t -> (a, Int, Memo t) } + +runHM :: (Hashable t) => HM t a -> (a, Memo t) +runHM hm = let (a,_,m) = unHM hm 0 mempty in (a,m) + +instance Functor (HM t) where + fmap f (HM h) = HM \n m -> h n m & _1 %~ f + +instance Applicative (HM t) where + pure a = HM \n m -> (a,n,m) + HM hf <*> HM ha = HM \n m -> + let (f',n',m') = hf n m + (a,n'',m'') = ha n' m' + in (f' a, n'', m'') + +instance Monad (HM t) where + HM ha >>= k = HM \n m -> + let (a,n',m') = ha n m + (a',n'',m'') = unHM (k a) n' m' + in (a',n'', m'') + +instance Hashable t => MonadWriter (Memo t) (HM t) where + -- IMPORTAN! (<>) is left-biased for HashMap! append `w` to the RIGHt! + writer (a,w) = HM \n m -> (a,n,m <> w) + listen ma = HM \n m -> + let (a,n',m') = unHM ma n m + in ((a,m'),n',m') + pass maww = HM \n m -> + let ((a,ww),n',m') = unHM maww n m + in (a,n',ww m') + +instance MonadState Int (HM t) where + state f = HM \n m -> + let (a,n') = f n + in (a,n',m) + +freshTv :: HM t (Type PsName) +freshTv = do + n <- get + modify succ + pure . VarT $ "$a" <> T.pack (show n) + +makePrisms ''PartialJudgement +