
Relta API Reference

Client

Source code in src/relta/client.py
class Client:
    def __init__(self, config: Optional[Configuration] = None):
        if config is None:
            config = Configuration()
        self.config = config

        logfire.configure(
            send_to_logfire="if-token-present",
            token=self.config.logfire_token if self.config.logfire_token else None,
            console=logfire.ConsoleOptions(
                min_log_level="warning", show_project_link=False
            ),
        )
        logfire.info("Client initialized")

        if not self.config.encryption_key:
            print(
                f"Encryption key not found in .env file, but may autogenerate: {config.auto_generate_encryption_key=}"
            )
            if self.config.auto_generate_encryption_key:
                key = Fernet.generate_key().decode(
                    "utf-8"
                )  # hacky character sets, but it should work
                self.config.encryption_key = key
                with open(".env", "a+") as f:
                    f.write(
                        f"\n# Generated by Relta on {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\nENCRYPTION_KEY={key}\n# End Relta generated key\n"
                    )
                print("Encryption key generated and added to .env file")

        self.engine = create_engine(
            f"duckdb:///{self.config.relta_internal_path}", echo=self.config.debug
        )

        folders = {
            k: v for k, v in self.config.model_dump().items() if k.endswith("_dir_path")
        }
        for fpath in folders:
            if not os.path.exists(folders[fpath]):
                logfire.info(f"Creating directory at {folders[fpath]}")
                os.mkdir(folders[fpath])

        # Populate ClassVar's from Configuration
        DataSource.conn = duckdb.connect(str(self.config.relta_data_path))
        DataSource.set_encryption_key(self.config.encryption_key)

        Chat.agent = SQLAgent(self.config)
        SQLAgent.model = ChatOpenAI(model="gpt-4o-2024-08-06", temperature=0)
        SQLAgent.mini_model = ChatOpenAI(model="gpt-4o-mini-2024-07-18", temperature=0)

        # register any models
        SQLModel.metadata.create_all(self.engine)

        # re-attach any databases that were connected to the client
        self._attach_databases()

    def _attach_databases(self):
        """iterates to connected datasources and attaches them to the client"""
        with Session(self.engine) as session:
            datasources = session.exec(select(DataSource)).all()
            for datasource in datasources:
                datasource._config = self.config
                datasource._attach()

    def create_datasource(
        self,
        connection_uri: str,
        name: Optional[str] = None,
        dtypes: Optional[dict[str, str]] = None,
    ) -> DataSource:
        """Creates a new datasource object and connects it to Relta

        Args:
            connection_uri (str): The connection_uri for the datasource
            name (str, optional): The datasource name. If none is provided Relta will assign a name
            dtypes (dict, optional): Map of column names to datatypes, overrides a column's auto-detected type.
                The datatypes should be [DuckDB datatypes](https://duckdb.org/docs/sql/data_types/overview). *Only for CSVs*.

        Raises:
            duckdb.CatalogException: Raised if a datasource with given name already exists
            duckdb.BinderException: Raised if Relta cannot connect to the given database.

        Returns:
            DataSource: The newly created Datasource object
        """
        with Session(self.engine) as session:
            session.expire_on_commit = False

            if connection_uri.endswith(".csv"):
                type = DataSourceType.CSV
            elif connection_uri.endswith(".parquet"):
                type = DataSourceType.PARQUET
            elif connection_uri.startswith("postgres"):
                type = DataSourceType.POSTGRES
            elif connection_uri.startswith("mysql"):
                type = DataSourceType.MYSQL
            elif connection_uri.endswith(".duckdb") or connection_uri.endswith(".ddb"):
                type = DataSourceType.DUCKDB

            if name is None:
                name = DataSource._infer_name_if_none(
                    type=type, connection_uri=connection_uri
                )
            logfire.info(
                "Creating datasource {name} with type {type}", name=name, type=type
            )
            ds = session.exec(select(DataSource).where(DataSource.name == name)).first()

            if ds is not None:
                raise DuplicateResourceException(
                    f"Datasource with {name} exist. Consider using get_datasource or get_or_create_datasource."
                )

            datasource = DataSource(type=type, connection_uri=connection_uri, name=name)
            datasource._config = self.config

            try:
                datasource.connect(dtypes=dtypes)
            except duckdb.CatalogException:
                logfire.error("A table with the same name already exists in Relta")
                raise duckdb.CatalogException(
                    "A table with the same name already exists in Relta"
                )
            except duckdb.BinderException:
                logfire.error("A table with the same name already exists in Relta")
                raise duckdb.BinderException(
                    "A database with same name is already connected to Relta"
                )

            datasource._semantic_layer = SemanticLayer(datasource, self.config)
            session.add(datasource)
            session.commit()
            return datasource

    def get_datasource(self, name: str) -> DataSource:
        """Returns a datasource object with given name or id

        Args:
            name (str): The name of the datasource. Must be passed in.

        Returns:
            DataSource: The Datasource object or None if it does not exist
        """
        with Session(self.engine) as session:
            ds = session.exec(select(DataSource).where(DataSource.name == name)).first()
            if ds is None:
                return None
            ds._config = self.config
            ds._semantic_layer = SemanticLayer(ds, self.config)
            return ds

    def get_or_create_datasource(self, name: str, connection_uri: str) -> DataSource:
        """If a datasource with the same name and same connection_uri exist we return it. Otherwise create a new one.

        Args:
            name (str): the name of the datasource to get or create
            connection_uri (str): the connection_uri to the datasource to get or create

        Returns:
            DataSource: The existing datasource or the new one
        """
        with Session(self.engine) as session:
            ds = session.exec(select(DataSource).where(DataSource.name == name)).first()

            if ds is None:
                return self.create_datasource(name=name, connection_uri=connection_uri)

            ds._config = self.config
            ds._semantic_layer = SemanticLayer(ds, self.config)
            return ds

    def delete_datasource(self, name: str) -> None:
        """Deletes DataSource and all associated Chat objects from Relta. Cannot be reversed.

        Args:
            name: the datasource name

        Raises:
            ValueError: If DataSource does not exist
        """

        datasource = self.get_datasource(name)
        if datasource is None:
            raise ValueError("Datasource does not exist")

        datasource._disconnect()

        # we have to delete Chats and split into two commits because of a limitation of duckdb indexes.
        with Session(self.engine) as session:
            session.expire_on_commit = False
            chats = session.exec(
                select(Chat).where(Chat.datasource_id == datasource.id)
            ).all()
            for chat in chats:
                session.delete(chat)
            session.commit()

        with Session(self.engine) as session:
            session.delete(datasource)
            session.commit()

    def get_sources(self) -> list[DataSource]:
        """Method to get all connected datasource objects

        Returns:
            list[DataSource]: A list containing DataSource objects for all connected sources
        """
        with Session(self.engine) as session:
            ds_lst = session.exec(select(DataSource)).all()
            for ds in ds_lst:
                ds._config = self.config
                ds._semantic_layer = SemanticLayer(ds, self.config)
            return ds_lst

    def show_sources(self) -> None:
        """Prints a table of all connected datasources to the console"""
        with Session(self.engine) as session:
            datasources = session.exec(select(DataSource)).all()
            datasources_dicts = [
                dict(sorted(datasource.model_dump().items()))
                for datasource in datasources
            ]
            headers = sorted(DataSource.model_fields)
            rows = [datasource.values() for datasource in datasources_dicts]

            print(tabulate(rows, headers=headers, tablefmt="pretty"))

    def create_chat(self, datasource: DataSource) -> Chat:
        """Creates a chat with the given DataSource"""
        logfire.info("New chat created with datasource {name}", name=datasource.name)
        chat = Chat(datasource_id=datasource.id, datasource=datasource)
        chat._config = self.config
        chat._responses = []
        with Session(self.engine) as session:
            session.expire_on_commit = False
            session.add(chat)
            session.commit()
        return chat

    def list_chats(self, datasource: DataSource) -> list[Chat]:
        """List all Chat objects for a given DataSource"""
        with Session(self.engine) as session:
            chats = session.exec(
                select(Chat).where(Chat.datasource_id == datasource.id)
            ).all()

            for chat in chats:
                chat._config = self.config
                chat._responses = []
                chat.datasource = datasource

            return chats

    def get_chat(self, datasource: DataSource, id: int) -> Chat:
        with Session(self.engine) as session:
            chat = session.exec(
                select(Chat)
                .where(Chat.id == id)
                .where(Chat.datasource_id == datasource.id)
            ).first()
            if chat is None:
                raise NameError(
                    f"Chat '{id}' does not exist on DataSource {datasource.name}"
                )
            chat._config = self.config
            chat._responses = []  # TODO: make responses load in when getting a chat again
            chat.datasource = datasource
            return chat
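
A minimal initialization sketch (assuming Client and Configuration are re-exported from the top-level relta package, and that OPENAI_API_KEY is available in the environment or a .env file):

import relta

# Default configuration: picks up relta.toml, a [tool.relta] table in
# pyproject.toml, environment variables, and .env if present.
client = relta.Client()

# Or pass an explicit Configuration to override defaults.
client = relta.Client(config=relta.Configuration(debug=True))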

create_chat(datasource)

Creates a chat with the given DataSource

Source code in src/relta/client.py
def create_chat(self, datasource: DataSource) -> Chat:
    """Creates a chat with the given DataSource"""
    logfire.info("New chat created with datasource {name}", name=datasource.name)
    chat = Chat(datasource_id=datasource.id, datasource=datasource)
    chat._config = self.config
    chat._responses = []
    with Session(self.engine) as session:
        session.expire_on_commit = False
        session.add(chat)
        session.commit()
    return chat
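
For example, a short sketch (the "orders" datasource name is hypothetical and must already be connected):

source = client.get_datasource("orders")
chat = client.create_chat(source)
print(chat.id)  # persisted chat id, usable later with get_chat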

create_datasource(connection_uri, name=None, dtypes=None)

Creates a new datasource object and connects it to Relta

Parameters:

    connection_uri (str, required): The connection_uri for the datasource
    name (str, optional, default None): The datasource name. If none is provided Relta will assign a name
    dtypes (dict, optional, default None): Map of column names to datatypes, overrides a column's auto-detected type. The datatypes should be DuckDB datatypes. Only for CSVs.

Raises:

    duckdb.CatalogException: Raised if a datasource with given name already exists
    duckdb.BinderException: Raised if Relta cannot connect to the given database.

Returns:

    DataSource: The newly created Datasource object

Source code in src/relta/client.py
def create_datasource(
    self,
    connection_uri: str,
    name: Optional[str] = None,
    dtypes: Optional[dict[str, str]] = None,
) -> DataSource:
    """Creates a new datasource object and connects it to Relta

    Args:
        connection_uri (str): The connection_uri for the datasource
        name (str, optional): The datasource name. If none is provided Relta will assign a name
        dtypes (dict, optional): Map of column names to datatypes, overrides a column's auto-detected type.
            The datatypes should be [DuckDB datatypes](https://duckdb.org/docs/sql/data_types/overview). *Only for CSVs*.

    Raises:
        duckdb.CatalogException: Raised if a datasource with given name already exists
        duckdb.BinderException: Raised if Relta cannot connect to the given database.

    Returns:
        DataSource: The newly created Datasource object
    """
    with Session(self.engine) as session:
        session.expire_on_commit = False

        if connection_uri.endswith(".csv"):
            type = DataSourceType.CSV
        elif connection_uri.endswith(".parquet"):
            type = DataSourceType.PARQUET
        elif connection_uri.startswith("postgres"):
            type = DataSourceType.POSTGRES
        elif connection_uri.startswith("mysql"):
            type = DataSourceType.MYSQL
        elif connection_uri.endswith(".duckdb") or connection_uri.endswith(".ddb"):
            type = DataSourceType.DUCKDB

        if name is None:
            name = DataSource._infer_name_if_none(
                type=type, connection_uri=connection_uri
            )
        logfire.info(
            "Creating datasource {name} with type {type}", name=name, type=type
        )
        ds = session.exec(select(DataSource).where(DataSource.name == name)).first()

        if ds is not None:
            raise DuplicateResourceException(
                f"Datasource with {name} exist. Consider using get_datasource or get_or_create_datasource."
            )

        datasource = DataSource(type=type, connection_uri=connection_uri, name=name)
        datasource._config = self.config

        try:
            datasource.connect(dtypes=dtypes)
        except duckdb.CatalogException:
            logfire.error("A table with the same name already exists in Relta")
            raise duckdb.CatalogException(
                "A table with the same name already exists in Relta"
            )
        except duckdb.BinderException:
            logfire.error("A table with the same name already exists in Relta")
            raise duckdb.BinderException(
                "A database with same name is already connected to Relta"
            )

        datasource._semantic_layer = SemanticLayer(datasource, self.config)
        session.add(datasource)
        session.commit()
        return datasource
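
A usage sketch (the CSV path, column names, and Postgres connection string below are hypothetical):

# CSV: column types can be pinned using DuckDB type names.
orders = client.create_datasource(
    "data/orders.csv",
    name="orders",
    dtypes={"order_id": "BIGINT", "created_at": "TIMESTAMP"},
)

# Postgres: when name is omitted, Relta infers it from the connection URI.
analytics = client.create_datasource(
    "postgresql://user:password@localhost:5432/analytics"
)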

delete_datasource(name)

Deletes DataSource and all associated Chat objects from Relta. Cannot be reversed.

Parameters:

    name (str, required): the datasource name

Raises:

    ValueError: If DataSource does not exist

Source code in src/relta/client.py
def delete_datasource(self, name: str) -> None:
    """Deletes DataSource and all associated Chat objects from Relta. Cannot be reversed.

    Args:
        name: the datasource name

    Raises:
        ValueError: If DataSource does not exist
    """

    datasource = self.get_datasource(name)
    if datasource is None:
        raise ValueError("Datasource does not exist")

    datasource._disconnect()

    # we have to delete Chats and split into two commits because of a limitation of duckdb indexes.
    with Session(self.engine) as session:
        session.expire_on_commit = False
        chats = session.exec(
            select(Chat).where(Chat.datasource_id == datasource.id)
        ).all()
        for chat in chats:
            session.delete(chat)
        session.commit()

    with Session(self.engine) as session:
        session.delete(datasource)
        session.commit()
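
For example (the datasource name is hypothetical; a ValueError is raised if it is not connected):

try:
    client.delete_datasource("orders")
except ValueError:
    print("No datasource named 'orders' is connected")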

get_datasource(name)

Returns a datasource object with the given name

Parameters:

    name (str, required): The name of the datasource. Must be passed in.

Returns:

    DataSource: The Datasource object or None if it does not exist

Source code in src/relta/client.py
def get_datasource(self, name: str) -> DataSource:
    """Returns a datasource object with given name or id

    Args:
        name (str): The name of the datasource. Must be passed in.

    Returns:
        DataSource: The Datasource object or None if it does not exist
    """
    with Session(self.engine) as session:
        ds = session.exec(select(DataSource).where(DataSource.name == name)).first()
        if ds is None:
            return None
        ds._config = self.config
        ds._semantic_layer = SemanticLayer(ds, self.config)
        return ds
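
For example (hypothetical name; note that the method returns None rather than raising when the datasource is missing):

source = client.get_datasource("orders")
if source is None:
    print("'orders' is not connected")
else:
    print(source.name, source.type, source.last_hydrated)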

get_or_create_datasource(name, connection_uri)

If a datasource with the same name and connection_uri exists, we return it. Otherwise a new one is created.

Parameters:

    name (str, required): the name of the datasource to get or create
    connection_uri (str, required): the connection_uri to the datasource to get or create

Returns:

    DataSource: The existing datasource or the new one

Source code in src/relta/client.py
def get_or_create_datasource(self, name: str, connection_uri: str) -> DataSource:
    """If a datasource with the same name and same connection_uri exist we return it. Otherwise create a new one.

    Args:
        name (str): the name of the datasource to get or create
        connection_uri (str): the connection_uri to the datasource to get or create

    Returns:
        DataSource: The existing datasource or the new one
    """
    with Session(self.engine) as session:
        ds = session.exec(select(DataSource).where(DataSource.name == name)).first()

        if ds is None:
            return self.create_datasource(name=name, connection_uri=connection_uri)

        ds._config = self.config
        ds._semantic_layer = SemanticLayer(ds, self.config)
        return ds
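
This makes idempotent setup scripts straightforward. A sketch with a hypothetical name and path:

source = client.get_or_create_datasource(
    name="orders",
    connection_uri="data/orders.csv",
)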

get_sources()

Method to get all connected datasource objects

Returns:

    list[DataSource]: A list containing DataSource objects for all connected sources

Source code in src/relta/client.py
def get_sources(self) -> list[DataSource]:
    """Method to get all connected datasource objects

    Returns:
        list[DataSource]: A list containing DataSource objects for all connected sources
    """
    with Session(self.engine) as session:
        ds_lst = session.exec(select(DataSource)).all()
        for ds in ds_lst:
            ds._config = self.config
            ds._semantic_layer = SemanticLayer(ds, self.config)
        return ds_lst
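
For example, listing every connected source:

for source in client.get_sources():
    print(source.name, source.type, source.last_hydrated)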

list_chats(datasource)

List all Chat objects for a given DataSource

Source code in src/relta/client.py
def list_chats(self, datasource: DataSource) -> list[Chat]:
    """List all Chat objects for a given DataSource"""
    with Session(self.engine) as session:
        chats = session.exec(
            select(Chat).where(Chat.datasource_id == datasource.id)
        ).all()

        for chat in chats:
            chat._config = self.config
            chat._responses = []
            chat.datasource = datasource

        return chats
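
A sketch that resumes the most recent chat on a datasource (names are hypothetical):

chats = client.list_chats(source)
if chats:
    chat = client.get_chat(source, chats[-1].id)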

show_sources()

Prints a table of all connected datasources to the console

Source code in src/relta/client.py
def show_sources(self) -> None:
    """Prints a table of all connected datasources to the console"""
    with Session(self.engine) as session:
        datasources = session.exec(select(DataSource)).all()
        datasources_dicts = [
            dict(sorted(datasource.model_dump().items()))
            for datasource in datasources
        ]
        headers = sorted(DataSource.model_fields)
        rows = [datasource.values() for datasource in datasources_dicts]

        print(tabulate(rows, headers=headers, tablefmt="pretty"))

Configuration

Bases: BaseSettings

Configuration class for Relta

Any attributes ending with _dir_path will be created when a Client object is initialized.

Source code in src/relta/config.py
class Configuration(BaseSettings):
    """Configuration class for Relta

    Any attributes ending with `_dir_path` will be created when a `Client` object is initialized.

    """

    # Unfortunately, default values cannot be `None`, so you will have to add some extra logic when using config variables that are optional.

    model_config = SettingsConfigDict(
        toml_file="relta.toml",
        pyproject_toml_table_header=("tool", "relta"),
        env_file=".env",
        extra="ignore",
    )

    @classmethod
    def settings_customise_sources(
        cls,
        settings_cls: Type[BaseSettings],
        init_settings: PydanticBaseSettingsSource,
        env_settings: PydanticBaseSettingsSource,
        dotenv_settings: PydanticBaseSettingsSource,
        file_secret_settings: PydanticBaseSettingsSource,
    ) -> tuple[PydanticBaseSettingsSource, ...]:
        return (
            init_settings,
            env_settings,
            TomlConfigSettingsSource(settings_cls),
            PyprojectTomlConfigSettingsSource(settings_cls),
            dotenv_settings,
            file_secret_settings,
        )

    relta_dir_path: Path = Path(".relta")
    relta_semantic_layer_dir_path: Path = relta_dir_path / "semantic_layer"
    transient_data_dir_path: Path = relta_dir_path / "data"
    relta_internal_path: Path = relta_dir_path / "relta_internal.duckdb"
    relta_data_path: Path = relta_dir_path / "relta_data.duckdb"
    chat_memory_path: Path = relta_dir_path / "chat_memory.sqlite"
    semantic_memory_path: Path = relta_dir_path / "semantic_memory.sqlite"
    storage_path: Path = relta_dir_path / "relta.sqlite"
    storage_table: str = "messages"
    storage_session_id_field: str = "session_id"
    openai_key: str = Field(alias="OPENAI_API_KEY")
    json_dumps_kwargs: dict = {"indent": 2}
    yaml_dumps_kwargs: dict = {"sort_keys": False}

    encryption_key: str = Field(alias="ENCRYPTION_KEY", default="")
    auto_generate_encryption_key: bool = Field(
        alias="AUTO_GENERATE_ENCRYPTION_KEY", default=True
    )

    debug: bool = False  # TODO: Further implement debug mode.
    anonymized_telemetry: bool = True
    low_cardinality_cutoff: int = 100

    github_token: str = Field(alias="GITHUB_TOKEN", default="")
    github_repo: str = Field(alias="GITHUB_REPO", default="")
    github_base_branch: str = Field(alias="GITHUB_BASE_BRANCH", default="")

    logfire_token: str = Field(alias="LOGFIRE_TOKEN", default="")
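
Settings are resolved in the order defined by settings_customise_sources: keyword arguments first, then environment variables, relta.toml, a [tool.relta] table in pyproject.toml, the .env file, and finally file secrets. A minimal override sketch (assuming Configuration is importable from the relta package and OPENAI_API_KEY is set in the environment):

from relta import Configuration

config = Configuration(
    debug=True,                  # overrides any value from relta.toml / .env
    low_cardinality_cutoff=50,
)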

DataSource

Bases: SQLModel

Source code in src/relta/datasource/datasource.py
class DataSource(SQLModel, table=True):
    conn: ClassVar[
        duckdb.DuckDBPyConnection
    ]  # populated when a Client is initialized. ClassVar is not serialized by SQLModel
    # this connection is for the external data source, not for accessing the Relta replica.

    id: Optional[int] = Field(default=source_id_seq.next_value(), primary_key=True)
    type: DataSourceType
    connection_uri: str = Field(sa_column=Column(String))
    name: str
    last_hydrated: Optional[datetime] = Field(default=None)

    # These private fields are populated on creation or load.
    # They aren't validated by SQLModel (even if you pass them in on __init__ as private attributes, no matter the ConfigDict(extra="allow"))
    _config: Configuration
    _semantic_layer: SemanticLayer

    chats: list["Chat"] = Relationship(back_populates="datasource")  # type: ignore # noqa: F821 # avoid circular import w quotes

    _encryption_key: ClassVar[bytes]  # Class-level encryption key

    @property
    def semantic_layer(self):
        """The semantic layer of the `DataSource`. Populated by `Client`"""
        return self._semantic_layer

    @property
    def decrypted_uri(self) -> str:
        """Access the decrypted connection URI"""
        return self._decrypt_uri(self.connection_uri, self._fernet)

    @classmethod
    def set_encryption_key(cls, key: bytes):
        """Set the encryption key for all DataSource instances"""
        cls._encryption_key = key
        cls._fernet = Fernet(key)

    def __init__(self, **data):
        if "connection_uri" in data:
            # Encrypt the connection URI before storing
            print(f"key is {self._encryption_key}")
            data["connection_uri"] = self._encrypt_uri(
                data["connection_uri"], self._fernet
            )
        super().__init__(**data)

    def _attach(self):
        """Attaches the database to Relta. This is called at client initialization.

        Raises:
            NotImplementedError: If the data source type is not supported
        """
        if self.type == DataSourceType.POSTGRES:
            self.conn.sql("INSTALL POSTGRES")
            self._connect_postgres()
        elif self.type == DataSourceType.MYSQL:
            self._connect_mysql()
        elif self.type == DataSourceType.CSV:
            self._connect_csv()
        elif self.type == DataSourceType.DUCKDB:
            self._connect_duckdb()
        else:
            raise NotImplementedError

    def _connect_duckdb(self):
        """Private method used to connect to a DuckDB datasource"""
        try:
            self.conn.sql(f"ATTACH '{self.decrypted_uri}' AS {self.name}")

            self.conn.sql(
                f"ATTACH '{self._config.transient_data_dir_path}/{self.name}.duckdb' AS transient_{self.name}"
            )

        except duckdb.BinderException as e:
            logfire.error(e)
            logfire.error(
                f"Unable to attach to DuckDB database with connection URI {self.decrypted_uri}. This could be because the connection_uri is wrong or database with same name already exists in Relta."
            )
            raise duckdb.BinderException

    def _connect_postgres(self):
        """Private method used to connect to a Postgres data source"""
        parsed_uri = urlparse(self.decrypted_uri)
        try:
            self.conn.sql(
                f"ATTACH 'dbname={parsed_uri.path.lstrip('/')} {f'host={parsed_uri.hostname}' if parsed_uri.hostname else ''} {f'user={parsed_uri.username}' if parsed_uri.username else ''} {f'password={parsed_uri.password}' if parsed_uri.password else ''} {f'port={parsed_uri.port}' if parsed_uri.port else ''}' AS {self.name} (TYPE POSTGRES, READ_ONLY)"
            )

            self.conn.sql(
                f"ATTACH '{self._config.transient_data_dir_path}/{self.name}.duckdb' AS transient_{self.name}"
            )
        except duckdb.BinderException as e:
            logfire.error(e)
            logfire.error(
                f"Unable to attach to Postgres database with name {self.name}. This could be because the connection_uri is wrong or database with same name already exists in Relta."
            )
            raise duckdb.BinderException

    def _connect_mysql(self):
        raise NotImplementedError

    def _connect_csv(self, dtypes: Optional[dict[str, str]] = None):
        """Private method used to connect to a CSV data source and create a table in Relta

        Args:
            dtypes (dict, optional): Map of column names to datatypes, overrides a column's auto-detected type.
                The datatypes should be [DuckDB datatypes](https://duckdb.org/docs/sql/data_types/overview).

        Raises:
            duckdb.CatalogException: raised if a table with the same name already exists in Relta
        """
        try:
            self.conn.sql(
                f"ATTACH '{self._config.transient_data_dir_path}/{self.name}.duckdb' AS transient_{self.name}"
            )
            self._load_csv(dtypes)
        except duckdb.CatalogException as e:
            logfire.error(e)
            logfire.error(
                f"Table with name {self.name} already exists in Relta. Please choose a different name or consider refreshing data using rehydrate()"
            )
            raise duckdb.CatalogException

    def _connect_parquet(self):
        raise NotImplementedError

    def connect(self, dtypes: Optional[dict[str, str]] = None):
        """Creates a connection to the given data source. This allows Relta to query the underlying data (e.g. read schema) but does not copy data into Relta.

        Args:
            dtypes (dict, optional): Map of column names to datatypes, overrides a column's auto-detected type.
                *Only for CSVs*. The datatypes should be [DuckDB datatypes](https://duckdb.org/docs/sql/data_types/overview).

        Raises:
            duckdb.CatalogException: If a table with the same name is already connected to Relta
        """
        if self.type == DataSourceType.CSV:
            self._connect_csv(dtypes)
        elif self.type == DataSourceType.PARQUET:
            self._connect_parquet()
        elif self.type == DataSourceType.POSTGRES:
            self._connect_postgres()
        elif self.type == DataSourceType.MYSQL:
            self._connect_mysql()
        elif self.type == DataSourceType.DUCKDB:
            self._connect_duckdb()

    def _disconnect(self):
        """Disconnects the data source from Relta

        Raises:
            duckdb.CatalogException: If the underlying data source does not exist in Relta
        """
        self.conn.sql(
            "ATTACH IF NOT EXISTS ':memory:' AS memory_db"
        )  # this is to guard in case the DB we are deleting is the default database
        try:
            if (
                self.type == DataSourceType.POSTGRES
                or self.type == DataSourceType.DUCKDB
            ):
                self.conn.sql("USE memory_db")
                self.conn.sql(f"DETACH {self.name}")
                os.remove(f"{self._config.transient_data_dir_path}/{self.name}.duckdb")
            elif self.type == DataSourceType.CSV:
                self.conn.sql("USE memory_db")
                self.conn.sql(f"DETACH  transient_{self.name}")
                os.remove(f"{self._config.transient_data_dir_path}/{self.name}.duckdb")

        except duckdb.CatalogException as e:
            logfire.error(e)
            logfire.error(
                f"Table with name {self.name} does not exist in Relta. Please check the name and try again"
            )
            raise duckdb.CatalogException

    def load(self):
        """Updates the data in Relta from the underlying data source"""
        if self.type == DataSourceType.POSTGRES:
            self._load_postgres()
        elif self.type == DataSourceType.DUCKDB:
            self._load_postgres()
        elif self.type == DataSourceType.CSV:
            self._load_csv()

    def _load_csv(self, dtypes: Optional[dict[str, str]] = None):
        self.last_hydrated = datetime.now()
        self.conn.sql(f"USE transient_{self.name}")
        if dtypes:
            create_table_sql = f"CREATE OR REPLACE TABLE {self.name} AS SELECT * from read_csv('{self.decrypted_uri}', types = {dtypes})"
        else:
            create_table_sql = f"CREATE OR REPLACE TABLE {self.name} AS SELECT * from read_csv('{self.decrypted_uri}')"

        self.conn.sql(create_table_sql)
        self.last_hydrated = datetime.now()

    def _load_postgres(self):
        self.conn.sql("USE relta_data")
        self.conn.sql(
            f"ATTACH IF NOT EXISTS '{self._config.transient_data_dir_path}/transient_{self.name}.duckdb' As transient_{self.name}"
        )

        self.conn.sql(f"USE transient_{self.name}")
        for metric in self._semantic_layer.metrics:
            self.conn.sql(
                f"CREATE TABLE IF NOT EXISTS {metric.name} AS {metric.sql_to_underlying_datasource}"
            )
            # for column in metric.dimensions:
            #    self.conn.sql(f"CREATE OR REPLACE VIEW {metric.name} AS SELECT {column}, {metric.name} FROM {self.name}")

        self.last_hydrated = datetime.now()  # TODO: this needs to be written to the database, but that is a client operation... what to do about this?
        # TODO: when to detach? it should be after hydrating?

    @deprecated(reason="Use DataSource().semantic_layer property instead")
    def create_semantic_layer(self) -> SemanticLayer:
        """Returns the semantic model of the data source"""
        self._semantic_layer = SemanticLayer(self, self._config)
        return self._semantic_layer

    @deprecated(reason="Use DataSource().semantic_layer property instead")
    def get_semantic_layer(self) -> SemanticLayer:
        return self._semantic_layer

    def _get_ddl(self) -> str:
        """Returns the DDL of the data source"""

        if self.type == DataSourceType.POSTGRES or self.type == DataSourceType.DUCKDB:
            ddl_list = self.conn.sql(
                f"select sql from duckdb_tables() where database_name='{self.name}' and schema_name != 'information_schema' and schema_name != 'pg_catalog'"
            ).fetchall()
            ddl = "\n".join([ddl[0] for ddl in ddl_list])
        elif self.type == DataSourceType.CSV:
            ddl = self.conn.sql(
                f"select sql from duckdb_tables() where table_name='{self.name}'"
            ).fetchone()[0]  # self.conn.sql(f"DESCRIBE {self.name}")

        return ddl

    def _create_metrics(self):
        self.conn.sql("USE relta_data")
        for metric in self._semantic_layer.metrics:
            # fully_formed_sql = self._append_db_to_table_name(metric.sql_to_underlying_datasource, f'transient_{self.name}')
            self.conn.sql(
                f"CREATE OR REPLACE VIEW {metric.name} AS select * from transient_{self.name}.{metric.name}"
            )

    def _execute_datasource_sql(self, sql: str):
        """Run SQL against the underlying datasource"""
        if self.type == DataSourceType.CSV:
            self.conn.sql(f"USE transient_{self.name}")
            return self.conn.sql(sql)
        else:
            raise NotImplementedError

    def _execute_sql(self, sql: str):
        self.conn.sql("USE relta_data")
        return self.conn.sql(sql)

    def _get_transient_ddl(self):
        # self.conn.sql(f"USE transient_{self.name}")
        return self.conn.sql(
            f"SELECT * FROM duckdb_tables() where database_name='transient_{self.name}'"
        ).fetchall()

    @staticmethod
    def _append_db_to_table_name(original_sql: str, db_name: str) -> str:
        """In DuckDB we need fully formed table and column names with database name appended. This method creates those.

        Args:
            original_sql (str): the sql we will be modifying

        Returns:
            str: The SQL statement with db name appended to table names
        """
        fully_formed_sql = original_sql
        table_names = list(parse_one(fully_formed_sql).find_all(exp.Table))
        tables = [
            str(table).partition(" ")[0] for table in table_names
        ]  # this is bc sqlglot returns the table name as '{TABLE NAME} AS {ALIAS}'
        tables = set(tables)
        for table in tables:
            fully_formed_sql = re.sub(
                r"\b" + re.escape(table) + r"\b",
                f"{db_name}.{table}",
                fully_formed_sql,
            )

        return fully_formed_sql

    def _create_transient_tables(self, calculate_statistics: bool = True):
        """Creates the transient tables in DuckDB

        Args:
            calculate_statistics (bool, optional): Calculate statistics (i.e. low cardinality columns) for each metric. Defaults to True.
        """
        self.conn.sql(f"USE transient_{self.name}")

        if self.type == DataSourceType.POSTGRES or self.type == DataSourceType.DUCKDB:
            for metric in self._semantic_layer.metrics:
                fully_formed_sql_to_underlying_source = self._append_db_to_table_name(
                    metric.sql_to_underlying_datasource, self.name
                )
                self.conn.sql(
                    f"CREATE OR REPLACE TABLE {metric.name} AS {fully_formed_sql_to_underlying_source}"
                )

        elif self.type == DataSourceType.CSV:
            for metric in self._semantic_layer.metrics:
                fully_formed_sql_to_underlying_source = self._append_db_to_table_name(
                    metric.sql_to_underlying_datasource, f"transient_{self.name}"
                )
                self.conn.sql(
                    f"CREATE OR REPLACE TABLE {metric.name} AS {fully_formed_sql_to_underlying_source}"
                )

        if calculate_statistics:
            # the following code identifies low cardinality columns
            for metric in self.semantic_layer.metrics:
                for dimension in metric.dimensions:
                    dimension.categories = []
                    cardinality = self.conn.sql(
                        f"SELECT approx_count_distinct({dimension.name}) from {metric.name}"
                    ).fetchone()[0]
                    if cardinality < 100 and not dimension.skip_categorical_load:
                        categories = [
                            value[0]
                            for value in self.conn.sql(
                                f"SELECT DISTINCT {dimension.name} FROM {metric.name}"
                            ).fetchall()
                        ]
                        dimension.categories = categories

    def deploy(self, statistics: bool = True):
        """
        Deploys the semantic layer to the data source.

        Args:
            statistics (bool, optional): Calculate statistics (i.e. low cardinality columns) for each metric. Defaults to True.
        """
        logfire.span(
            "deploying semantic layer {semantic_layer}",
            semantic_layer=self.semantic_layer,
        )
        self._drop_removed_metrics()
        self._create_transient_tables(statistics)
        self._create_metrics()
        logfire.info("semantic layer deployed")

    def _drop_removed_metrics(self):
        """Checks the current list of metrics against views and transient tables. Drop them if they are no longer in the semantic layer"""
        self.conn.sql("use relta_data")
        if self.type == DataSourceType.CSV:  # on CSV we copy the entire data as a table
            existing_metrics = self.conn.sql(
                f"SELECT table_name FROM duckdb_tables() where database_name='transient_{self.name}' and table_name!='{self.name}'"
            ).fetchall()
        else:
            existing_metrics = self.conn.sql(
                f"SELECT table_name FROM duckdb_tables() where database_name='transient_{self.name}'"
            ).fetchall()

        existing_metric_names = [metric[0] for metric in existing_metrics]
        metric_names_to_keep = [metric.name for metric in self._semantic_layer.metrics]

        for metric_name in existing_metric_names:
            if metric_name not in metric_names_to_keep:
                self.conn.sql(
                    f"DROP TABLE IF EXISTS transient_{self.name}.{metric_name}"
                )
                self.conn.sql(f"DROP VIEW IF EXISTS {metric_name}")

    @staticmethod
    def _infer_name_if_none(type: DataSourceType, connection_uri: str) -> str:
        if type is DataSourceType.CSV or type is DataSourceType.DUCKDB:
            name = (
                connection_uri.split("/")[-1]
                .split(".")[0]
                .replace(" ", "_")
                .replace("-", "_")
            )
        elif type is DataSourceType.POSTGRES or type is DataSourceType.DUCKDB:
            name = connection_uri.split("/")[-1]

        return name

    @staticmethod
    def _encrypt_uri(connection_uri: str, key) -> str:
        """Encrypt a connection URI"""
        encrypted = key.encrypt(connection_uri.encode())
        return b64encode(encrypted).decode()

    @staticmethod
    def _decrypt_uri(connection_uri: str, key) -> str:
        """Decrypt the stored connection URI"""
        encrypted = b64decode(connection_uri.encode())
        return key.decrypt(encrypted).decode()

decrypted_uri: str property

Access the decrypted connection URI

semantic_layer property

The semantic layer of the DataSource. Populated by Client

connect(dtypes=None)

Creates a connection to the given data source. This allows Relta to query the underlying data (e.g. read schema) but does not copy data into Relta.

Parameters:

    dtypes (dict, optional, default None): Map of column names to datatypes, overrides a column's auto-detected type. Only for CSVs. The datatypes should be DuckDB datatypes.

Raises:

    duckdb.CatalogException: If a table with the same name is already connected to Relta

Source code in src/relta/datasource/datasource.py
def connect(self, dtypes: Optional[dict[str, str]] = None):
    """Creates a connection to the given data source. This allows Relta to query the underlying data (e.g. read schema) but does not copy data into Relta.

    Args:
        dtypes (dict, optional): Map of column names to datatypes, overrides a column's auto-detected type.
            *Only for CSVs*. The datatypes should be [DuckDB datatypes](https://duckdb.org/docs/sql/data_types/overview).

    Raises:
        duckdb.CatalogException: If a table with the same name is already connected to Relta
    """
    if self.type == DataSourceType.CSV:
        self._connect_csv(dtypes)
    elif self.type == DataSourceType.PARQUET:
        self._connect_parquet()
    elif self.type == DataSourceType.POSTGRES:
        self._connect_postgres()
    elif self.type == DataSourceType.MYSQL:
        self._connect_mysql()
    elif self.type == DataSourceType.DUCKDB:
        self._connect_duckdb()

create_semantic_layer()

Returns the semantic model of the data source

Source code in src/relta/datasource/datasource.py
@deprecated(reason="Use DataSource().semantic_layer property instead")
def create_semantic_layer(self) -> SemanticLayer:
    """Returns the semantic model of the data source"""
    self._semantic_layer = SemanticLayer(self, self._config)
    return self._semantic_layer

deploy(statistics=True)

Deploys the semantic layer to the data source.

Parameters:

    statistics (bool, optional, default True): Calculate statistics (i.e. low cardinality columns) for each metric. Defaults to True.

Source code in src/relta/datasource/datasource.py
def deploy(self, statistics: bool = True):
    """
    Deploys the semantic layer to the data source.

    Args:
        statistics (bool, optional): Calculate statistics (i.e. low cardinality columns) for each metric. Defaults to True.
    """
    logfire.span(
        "deploying semantic layer {semantic_layer}",
        semantic_layer=self.semantic_layer,
    )
    self._drop_removed_metrics()
    self._create_transient_tables(statistics)
    self._create_metrics()
    logfire.info("semantic layer deployed")

load()

Updates the data in Relta from the underlying data source

Source code in src/relta/datasource/datasource.py
def load(self):
    """Updates the data in Relta from the underlying data source"""
    if self.type == DataSourceType.POSTGRES:
        self._load_postgres()
    elif self.type == DataSourceType.DUCKDB:
        self._load_postgres()
    elif self.type == DataSourceType.CSV:
        self._load_csv()

set_encryption_key(key) classmethod

Set the encryption key for all DataSource instances

Source code in src/relta/datasource/datasource.py
@classmethod
def set_encryption_key(cls, key: bytes):
    """Set the encryption key for all DataSource instances"""
    cls._encryption_key = key
    cls._fernet = Fernet(key)

Chat

Bases: SQLModel

A Thread of conversation.

Contains metadata around a thread and exposes simple calls to the agent.

Source code in src/relta/chat/chat.py
class Chat(SQLModel, table=True):
    """A Thread of conversation.

    Contains metadata around a thread and exposes simple calls to the agent.
    """

    model_config = ConfigDict(extra="allow")

    id: Optional[int] = Field(default=chat_id_seq.next_value(), primary_key=True)

    datasource_id: int = Field(foreign_key="datasource.id")
    datasource: DataSource = Relationship(back_populates="chats")
    # _config: Configuration  # set by Client when this is created (or loaded) # why does this not work?

    agent: ClassVar[SQLAgent]  # populated when a Client is initialized.
    _responses: list["Response"]  # populated when a Client is initialized.

    @property
    def responses(self) -> list[Response]:
        if self._responses is None:
            self._responses = []
        return self._responses

    def _get_messages(self) -> list[AnyMessage]:
        checkpoint = self.agent.checkpointer.get(
            {"configurable": {"thread_id": self.id}}
        )
        return checkpoint["channel_values"]["messages"]

    def prompt(
        self,
        s: str,
        debug=False,
        mode: Literal[
            AgentMode.COMPLETE, AgentMode.DATA_ONLY, AgentMode.SQL_ONLY
        ] = "complete",
        **kwargs,
    ) -> Response:
        """Ask a question on the thread.

        Args:
            s (str): question
            debug (bool, optional): If True, runs the graph in debug mode. If None, uses the default value set in the config. Defaults to None.
            mode (Literal['complete', 'data_only', 'sql_only'], optional): The mode in which to run the agent. Defaults to 'complete'.
            **kwargs: additional keyword arguments to pass to the agent. See `SQLAgent.invoke` for arguments.
        Raises:
            ValueError: if the Chat has no id (it has not been persisted)

        Returns:
            Response: the response from the agent
        """
        logfire.span(
            "New prompt submitted to chat on datasource {datasource_name}",
            datasource_name=self.datasource.name,
        )
        if self.id is None:
            raise ValueError("Chat has no id")

        final_state = self.agent.invoke(
            prompt=s,
            datasource=self.datasource,
            thread_id=self.id,
            debug=debug,
            mode=mode,
            **kwargs,
        )

        logfire.info("Response from chat")
        resp = Response(
            chat=self,
            id=len(self.responses),
            message=final_state["messages"][-1] if "messages" in final_state else None,
            text=final_state["messages"][-1].content
            if "messages" in final_state
            else None,
            sql=final_state["sql_generation"].sql
            if final_state.get("sql_generation", None) is not None
            else None,
            sql_result=final_state.get("query_results", None),
        )

        self._responses.append(resp)
        return resp

prompt(s, debug=False, mode='complete', **kwargs)

Ask a question on the thread.

Parameters:

    s (str, required): question
    debug (bool, optional, default False): If True, runs the graph in debug mode. If None, uses the default value set in the config.
    mode (Literal['complete', 'data_only', 'sql_only'], optional, default 'complete'): The mode in which to run the agent.
    **kwargs: additional keyword arguments to pass to the agent. See SQLAgent.invoke for arguments.

Raises:

    ValueError: if the Chat has no id (it has not been persisted)

Returns:

    Response: the response from the agent

Source code in src/relta/chat/chat.py
def prompt(
    self,
    s: str,
    debug=False,
    mode: Literal[
        AgentMode.COMPLETE, AgentMode.DATA_ONLY, AgentMode.SQL_ONLY
    ] = "complete",
    **kwargs,
) -> Response:
    """Ask a question on the thread.

    Args:
        s (str): question
        debug (bool, optional): If True, runs the graph in debug mode. If None, uses the default value set in the config. Defaults to None.
        mode (Literal['complete', 'data_only', 'sql_only'], optional): The mode in which to run the agent. Defaults to 'complete'.
        **kwargs: additional keyword arguments to pass to the agent. See `SQLAgent.invoke` for arguments.
    Raises:
        ValueError: if the Chat has no id (it has not been persisted)

    Returns:
        Response: the response from the agent
    """
    logfire.span(
        "New prompt submitted to chat on datasource {datasource_name}",
        datasource_name=self.datasource.name,
    )
    if self.id is None:
        raise ValueError("Chat has no id")

    final_state = self.agent.invoke(
        prompt=s,
        datasource=self.datasource,
        thread_id=self.id,
        debug=debug,
        mode=mode,
        **kwargs,
    )

    logfire.info("Response from chat")
    resp = Response(
        chat=self,
        id=len(self.responses),
        message=final_state["messages"][-1] if "messages" in final_state else None,
        text=final_state["messages"][-1].content
        if "messages" in final_state
        else None,
        sql=final_state["sql_generation"].sql
        if final_state.get("sql_generation", None) is not None
        else None,
        sql_result=final_state.get("query_results", None),
    )

    self._responses.append(resp)
    return resp
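
A usage sketch (the question is hypothetical; the response fields follow the Response construction shown above):

chat = client.create_chat(source)
resp = chat.prompt("How many orders were placed last month?")
print(resp.text)        # natural-language answer, if one was generated
print(resp.sql)         # generated SQL, if any
print(resp.sql_result)  # raw query results, if any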

AgentMode

Bases: str, Enum

Mode in which the agent operates.

Source code in src/relta/agents/sql_agent.py
class AgentMode(str, Enum):
    """Mode in which the agent operates."""

    COMPLETE = "complete"  # Returns full response with SQL, data, and natural language
    SQL_ONLY = "sql_only"  # Only returns the generated SQL
    DATA_ONLY = (
        "data_only"  # Returns SQL and query results, but no natural language response
    )
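
Because AgentMode is a str Enum, the plain string values can be passed directly to Chat.prompt. A sketch:

resp = chat.prompt("Total revenue by month", mode="sql_only")
print(resp.sql)  # in sql_only mode the graph stops after SQL generation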

SQLAgent

Default SQL Generation Agent.

Supports persisting threads. Directly generates SQL from DDL. Used internally for relta.chat.Chat

Source code in src/relta/agents/sql_agent.py
class SQLAgent:
    """Default SQL Generation Agent.

    Supports persisting threads. Directly generates SQL from DDL. Used internally for `relta.chat.Chat`
    """

    model: ClassVar[ChatOpenAI]  # populated by Client
    mini_model: ClassVar[ChatOpenAI]  # populated by Client
    _config: Configuration

    @staticmethod
    def mask_metrics(metrics: list[Metric]) -> list[Metric]:
        """Note: this is not a node!"""
        masked_metrics = [m.model_copy(deep=True) for m in metrics]
        for metric in masked_metrics:
            metric.sql_to_underlying_datasource = None
        return masked_metrics

    @staticmethod
    def select_metric(state: State, config: RunnableConfig):
        prompt_template = ChatPromptTemplate.from_messages(
            [
                ("system", METRIC_MATCH_PROMPT),
                ("placeholder", "{messages}"),
            ]
        )
        chain = prompt_template | SQLAgent.model.with_structured_output(SQLGeneration)
        masked_metrics = SQLAgent.mask_metrics(state["metrics"])

        # TODO avoid generating SQL, but for now, wipe out the SQL fields
        res: SQLGeneration = chain.invoke(
            {
                "metrics": [m.model_dump_json() for m in masked_metrics],
                "messages": state["messages"],
            }
        )
        res.sql = None
        res.sql_reasoning = None

        return {"sql_generation": res}

    @staticmethod
    def generate_sql(state: State, config: RunnableConfig):
        # examples_str = "Examples:\n{examples}\n"
        # examples_str = ""
        # for example in state["examples"]:
        # examples_str += f"### Example\n#### Prompt\n{example.prompt}\n\n#### SQL\n{example.sql}\n\n#### Explanation\n{example.explanation}\n\n"
        sql_gen = state["sql_generation"]
        if sql_gen.metric is None or sql_gen.metric.lower() == "none":
            return {"messages": [SystemMessage("No metric chosen.")]}

        # TODO filter examples by metric?
        prompt_template = ChatPromptTemplate.from_messages(
            [
                ("system", MATCH_PROMPT),
                ("system", "Metric:\n{metric}"),
                ("system", "Examples:\n{examples}\n"),
                # ("system", f"Read the prompt again: \n{MATCH_PROMPT}"),
                ("placeholder", "{messages}"),
            ]
        )
        chain = prompt_template | SQLAgent.model.with_structured_output(SQLGeneration)
        masked_metrics = SQLAgent.mask_metrics(state["metrics"])
        matched_metric = [
            m.model_dump_json() for m in masked_metrics if m.name == sql_gen.metric
        ][0]

        res: SQLGeneration = chain.invoke(
            {
                "metric": matched_metric,
                "examples": [e.model_dump_json() for e in state["examples"]],
                "messages": state["messages"],
            }
        )

        return {"sql_generation": res}

    @staticmethod
    def execute_sql(state: State, config: RunnableConfig):
        # sql_gen = SQLGeneration.model_validate_json(state["sql_generation"])
        sql_gen = state["sql_generation"]
        sql = sql_gen.sql

        if (
            sql is None
            or sql.lower() == "none"
            or sql_gen.metric is None
            or sql_gen.metric.lower() == "none"
        ):
            return {
                "messages": [SystemMessage("No SQL query generated.")],
            }

        if config["configurable"]["fuzz"]:
            # random fuzzing (fast) w/ sqlglot, numpy
            # outp = {}
            # for select in parse_one(sql).find_all(exp.Select):
            #     for projection in select.expressions:
            #         outp[projection.alias_or_name] = np.random.randint(0, 100)
            llm = ChatOpenAI(model="gpt-4o-mini", temperature=0.0)
            prompt_template = ChatPromptTemplate.from_messages(
                [
                    ("system", FUZZ_PROMPT),
                    ("system", "{sql}"),
                    ("system", "Here is an interpretation of the DDL for the table:"),
                    ("system", "{metrics}"),
                    # ("placeholder", "{messages}"),
                ]
            )
            chain = prompt_template | llm.with_structured_output(FuzzedData)

            masked_metrics = SQLAgent.mask_metrics(state["metrics"])

            res: FuzzedData = chain.invoke(
                {
                    "sql": sql,
                    "metrics": [m.model_dump_json() for m in masked_metrics],
                    # "messages": state["messages"],
                }
            )
            outp = res.data
        else:
            try:
                outp = state["datasource"]._execute_sql(sql).fetchall()
            except Exception as e:
                return {
                    "messages": [SystemMessage(f"Error executing SQL: {str(e)}")],
                    "sql_execution_error": str(e),
                    "query_results": [],
                }

        return {
            "query_results": outp,
            "sql_execution_error": None,
            # "messages": [AIMessage(state["sql_generation"].model_dump_json())],
        }

    @staticmethod
    def repair_sql(state: State, config: RunnableConfig):
        # improve this
        # only runs if there is a failure code.
        sql_gen = state["sql_generation"]
        # sql_gen = SQLGeneration.model_validate_json(state["sql_generation"])
        # user_prompt = state["messages"][-3].content

        prompt_template = ChatPromptTemplate.from_messages(
            [
                ("system", REPAIR_PROMPT),
                ("placeholder", "{messages}"),
            ]
        )

        chain = prompt_template | SQLAgent.model.with_structured_output(SQLGeneration)
        res: SQLGeneration = chain.invoke(
            {
                "SQL": sql_gen.sql,
                "ERROR": state["sql_execution_error"],
                "TRANSIENT_DDL": state["datasource"]._get_transient_ddl(),
                "messages": state["messages"],
            }
        )

        return {"sql_generation": res, "n_retries": state["n_retries"] - 1}

    @staticmethod
    def respond(state: State, config: RunnableConfig):
        prompt_template = ChatPromptTemplate.from_messages(
            [
                ("system", RESPONSE_PROMPT),
                ("system", MATCH_PROMPT),
                ("system", "{metrics}"),
                ("placeholder", "{messages}"),
                ("system", "{query_results}"),
            ]
        )

        # Find the last HumanMessage in state messages
        user_question = None
        for msg in reversed(state["messages"]):
            if isinstance(msg, HumanMessage):
                user_question = msg.content
                break

        chain = prompt_template | SQLAgent.model
        res = chain.invoke({**state, "user_question": user_question})
        return {"messages": [AIMessage(state["sql_generation"].model_dump_json()), res]}

    def __init__(self, config: Configuration):
        self._config = config
        self.checkpointer = SqliteSaver(
            sqlite3.connect(self._config.chat_memory_path, check_same_thread=False)
        )

        graph_builder = StateGraph(State)

        graph_builder.add_node("select_metric", SQLAgent.select_metric)
        graph_builder.add_node("generate_sql", SQLAgent.generate_sql)
        graph_builder.add_node("execute_sql", SQLAgent.execute_sql)
        graph_builder.add_node("repair_sql", SQLAgent.repair_sql)
        graph_builder.add_node("response", SQLAgent.respond)

        graph_builder.add_edge(START, "select_metric")
        graph_builder.add_edge("select_metric", "generate_sql")
        # graph_builder.add_conditional_edges(
        #     "select_metric",
        #     lambda state, config: "generate_sql"
        #     if state.get("sql_generation", None) and state["sql_generation"].metric is not None
        #     else END,
        # )

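        # Mode-based branching: SQL_ONLY stops after generate_sql, DATA_ONLY stops
        # after execute_sql, and COMPLETE continues through the response node.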
        graph_builder.add_conditional_edges(
            "generate_sql",
            lambda state, config: "execute_sql"
            if AgentMode(config["configurable"]["mode"]) is not AgentMode.SQL_ONLY
            else END,
        )
        graph_builder.add_conditional_edges(
            "execute_sql",
            lambda state, config: "repair_sql"
            if state.get("sql_execution_error", None) is not None
            and state.get("n_retries", 0) > 0
            else "response"
            if AgentMode(config["configurable"]["mode"]) is not AgentMode.DATA_ONLY
            else END,
        )
        graph_builder.add_edge("repair_sql", "execute_sql")
        graph_builder.add_edge("response", END)

        self.graph = graph_builder.compile(checkpointer=self.checkpointer)

    def _get_mermaid(self):
        return self.graph.get_graph().draw_mermaid()

    def invoke(
        self,
        prompt: str,
        datasource: DataSource,
        thread_id: Optional[int] = None,
        debug: Optional[bool] = None,
        mode: AgentMode = AgentMode.COMPLETE,
        fuzz: bool = False,
        n_retries: int = 1,
    ) -> State:
        """Ask the agent a question.

        Args:
            prompt (str): Question
            datasource (DataSource): datasource to use for the question
            thread_id (int, optional): ID of the thread/chat to use for persistence. If None, runs the graph on a non-persisted thread. Defaults to None.
            debug (bool, optional): If True, runs the graph in debug mode. If None, uses the default value set in the config. Defaults to None.
            mode (AgentMode, optional): Controls how much of the graph runs: SQL_ONLY stops after SQL generation, DATA_ONLY skips the natural-language response, and COMPLETE runs the full graph. Defaults to AgentMode.COMPLETE.
            fuzz (bool): Use an LLM to generate fake query results instead of executing the SQL. Defaults to False.
            n_retries (int, optional): Number of times to attempt to repair the SQL if the SQL execution fails. Defaults to 1.
        Returns:
            State: full state of the agent after the invocation
        """
        if debug is None:
            debug = self._config.debug

        state_update: State = {
            "messages": [HumanMessage(prompt)],
            "metrics": datasource.semantic_layer.metrics,
            "examples": datasource.semantic_layer.examples,
            "datasource": datasource,
            "sql_generation": None,
            "sql_execution_error": None,
            "n_retries": n_retries,
            "query_results": None,
        }

        config: AgentConfig = {
            "mode": mode,
            "fuzz": fuzz,
            "thread_id": thread_id,
        }

        return self.graph.invoke(
            state_update,
            config={"configurable": config},
            debug=debug,
        )

invoke(prompt, datasource, thread_id=None, debug=None, mode=AgentMode.COMPLETE, fuzz=False, n_retries=1)

Ask the agent a question.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| prompt | str | Question | required |
| datasource | DataSource | Datasource to use for the question | required |
| thread_id | int | ID of the thread/chat to use for persistence. If None, runs the graph on a non-persisted thread. | None |
| debug | bool | If True, runs the graph in debug mode. If None, uses the default value set in the config. | None |
| mode | AgentMode | Controls how much of the graph runs: SQL_ONLY stops after SQL generation, DATA_ONLY skips the natural-language response, and COMPLETE runs the full graph. | AgentMode.COMPLETE |
| fuzz | bool | Use an LLM to generate fake query results instead of executing the SQL. | False |
| n_retries | int | Number of times to attempt to repair the SQL if the SQL execution fails. | 1 |

Returns:

| Type | Description |
| --- | --- |
| State | Full state of the agent after the invocation |

Source code in src/relta/agents/sql_agent.py
def invoke(
    self,
    prompt: str,
    datasource: DataSource,
    thread_id: Optional[int] = None,
    debug: Optional[bool] = None,
    mode: AgentMode = AgentMode.COMPLETE,
    fuzz: bool = False,
    n_retries: int = 1,
) -> State:
    """Ask the agent a question.

    Args:
        prompt (str): Question
        datasource (DataSource): datasource to use for the question
        thread_id (int, optional): ID of the thread/chat to use for persistence. If None, runs the graph on a non-persisted thread. Defaults to None.
        debug (bool, optional): If True, runs the graph in debug mode. If None, uses the default value set in the config. Defaults to None.
        mode (AgentMode, optional): Controls how much of the graph runs: SQL_ONLY stops after SQL generation, DATA_ONLY skips the natural-language response, and COMPLETE runs the full graph. Defaults to AgentMode.COMPLETE.
        fuzz (bool): Use an LLM to generate fake query results instead of executing the SQL. Defaults to False.
        n_retries (int, optional): Number of times to attempt to repair the SQL if the SQL execution fails. Defaults to 1.
    Returns:
        State: full state of the agent after the invocation
    """
    if debug is None:
        debug = self._config.debug

    state_update: State = {
        "messages": [HumanMessage(prompt)],
        "metrics": datasource.semantic_layer.metrics,
        "examples": datasource.semantic_layer.examples,
        "datasource": datasource,
        "sql_generation": None,
        "sql_execution_error": None,
        "n_retries": n_retries,
        "query_results": None,
    }

    config: AgentConfig = {
        "mode": mode,
        "fuzz": fuzz,
        "thread_id": thread_id,
    }

    return self.graph.invoke(
        state_update,
        config={"configurable": config},
        debug=debug,
    )
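
A minimal usage sketch (illustrative only: the configuration, datasource `ds`, and the question are placeholders, and `SQLAgent`/`AgentMode` are assumed to be imported from relta's sql_agent module):

agent = SQLAgent(config)                              # config: a relta Configuration
state = agent.invoke(
    prompt="How many orders were placed last month?", # illustrative question
    datasource=ds,                                    # an existing DataSource with a semantic layer
    mode=AgentMode.COMPLETE,
    n_retries=2,
)
print(state["sql_generation"].sql)                    # the generated SQL
print(state["messages"][-1].content)                  # natural-language answer (COMPLETE mode only)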

mask_metrics(metrics) staticmethod

Note: this is not a node!

Source code in src/relta/agents/sql_agent.py
@staticmethod
def mask_metrics(metrics: list[Metric]) -> list[Metric]:
    """Note: this is not a node!"""
    masked_metrics = [m.model_copy(deep=True) for m in metrics]
    for metric in masked_metrics:
        metric.sql_to_underlying_datasource = None
    return masked_metrics
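
For illustration, masking returns deep copies with the mapping SQL removed, leaving the originals untouched (assuming `layer` is a populated SemanticLayer):

masked = SQLAgent.mask_metrics(layer.metrics)
assert all(m.sql_to_underlying_datasource is None for m in masked)
# the metrics in layer.metrics keep their sql_to_underlying_datasource values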

SemanticLayer

Source code in src/relta/semantic/semantic_layer.py
class SemanticLayer:
    def __init__(
        self,
        datasource: "DataSource",  # noqa: F821 # type: ignore
        config: Configuration,
        load: bool = True,
        path: Optional[str] = None,
    ):
        from ..datasource import (
            DataSource,
        )  # this is to allow type hinting when writing code w/o a circular import (similar to other libraries)

        self._config = config
        self.datasource: DataSource = datasource
        self.path = (
            self._config.relta_semantic_layer_dir_path / self.datasource.name
            if path is None
            else Path(path)
        )
        self.update_reasoning: str = (
            ""  # semantic agent will write reasoning about the update here
        )
        self.feedback_responses = []
        self.metrics: list[Metric] = []
        self.examples: list[Example] = []
        # self.proposed_changes: list[Metric] = []
        makedirs(self.path, exist_ok=True)
        if load:
            self.load()

    def load(
        self,
        path: Optional[Union[str, Path]] = None,
        json: bool = True,
        metrics_to_load: Optional[list[str]] = None,
    ):
        """Load semantic layer.

        Changes to the metrics are not persisted on disk. Use `.dump()` to persist them.

        Args:
            path (Optional[Union[str, Path]], optional): Path to load the semantic layer. If None, uses `self.path`, which is populated on creation.
            json (bool, optional): Whether to additionally load the semantic layer from deprecated JSON files.
                If a metric exists in both JSON and YAML files by `metric.name`, the JSON metric is ignored. Defaults to True.
            metrics_to_load (Optional[list[str]], optional): List of metric names to load. If None, loads all metrics. Defaults to None.
        """
        logfire.info("loading semantic layer from {path}", path=str(path))
        p = Path(path) if path is not None else self.path
        metrics = []
        yaml_metric_names = set()
        examples = []

        for fpath in p.glob("*.yaml"):
            with open(fpath, "r") as f:
                data = yaml.safe_load(f)
                if fpath.name == "examples.yaml":
                    example_coll = ExampleCollection.model_validate(data)
                    examples.extend(example_coll.examples)
                else:
                    metric = Metric.model_validate(data)
                    if metrics_to_load is None or metric.name in metrics_to_load:
                        metrics.append(metric)
                        yaml_metric_names.add(metric.name)

        if json:
            for fpath in p.glob("*.json"):
                with open(fpath, "r") as f:
                    if fpath.name == "examples.json":
                        example_coll = ExampleCollection.model_validate_json(f.read())
                        examples.extend(example_coll.examples)
                    else:
                        metric = Metric.model_validate_json(f.read())
                        if (
                            metrics_to_load is None or metric.name in metrics_to_load
                        ) and metric.name not in yaml_metric_names:
                            metrics.append(metric)
        self.metrics = metrics
        self.examples = examples

    def dump(
        self,
        clear=True,
        path: Optional[Union[str, Path]] = None,
        # apply_proposals: bool = True,
    ):
        """Dumps the semantic layer, accepting any updates made to the semantic layer.

        Args:
            clear (bool): Delete all JSON/YAML files in the path for this layer. Defaults to True. See `path` attribute for details on the path.
            path (Optional[Union[str, Path]], optional): Path to dump the semantic layer. If None, uses `self.path`, which is populated on creation.
        """

        logfire.info("dumping semantic layer to file")
        p = Path(path) if path is not None else self.path

        if clear:
            for yaml_file in p.glob("*.yaml"):
                yaml_file.unlink()
            for json_file in p.glob("*.json"):
                json_file.unlink()

        for metric in self.metrics:
            with open(p / f"{metric.name}.yaml", "w+") as f:
                yaml.dump(metric.model_dump(), f, sort_keys=False)

        examples = ExampleCollection(examples=self.examples)
        with open(p / "examples.yaml", "w+") as f:
            yaml.dump(examples.model_dump(), f, sort_keys=False)

        # additionally, as dumping is "accepting" the changes, we clean up any updated state
        self.update_reasoning = ""

    def dumps(self, mode: Literal["json", "yaml"] = "yaml", **kwargs) -> str:
        """Dumps the metrics and examples to a JSON or YAML string. JSON is typically used for feeding into an agent and YAML for display.

        Args:
            mode (Literal["json", "yaml"], optional): Output format. Defaults to "yaml".
            **kwargs: Keyword arguments to pass to `pydantic.BaseModel.model_dump_json` or `yaml.dump` depending on mode.
                Will override, by individual key, the default kwargs for printing JSON/YAML set in `Configuration`.

        Returns:
            str: JSON or YAML representation of the semantic layer (metrics and examples), depending on `mode`.
        """
        default_kwargs = (
            self._config.json_dumps_kwargs
            if mode == "json"
            else self._config.yaml_dumps_kwargs
        )

        default_kwargs.update(kwargs)

        ctr = SemanticLayerContainer(**vars(self))

        return (
            ctr.model_dump_json(**default_kwargs)
            if mode == "json"
            else yaml.dump(ctr.model_dump(), stream=None, **default_kwargs)
        )

    def refine(self, pr=False):
        """Refines the semantic layer based on the feedback and creates a PR with the changes. By default, sets the refined metrics, but does not persist on disk -- see `dump()` to persist.

        If `pr=True`, attempts to create a PR on the configured GitHub repo (see `Configuration`) after updating the metrics in the in-memory semantic layer.
        If it is successful, returns the URL of the PR. Else, returns `None`.
        """
        logfire.info("refine semantic layer")
        llm = ChatOpenAI(model="gpt-4o", temperature=0).with_structured_output(
            RefinedMetrics
        )
        prompt = PromptTemplate.from_template(REFINE_SEMANTIC_LAYER_PROMPT)

        chain = prompt | llm

        feedback = [
            Feedback(
                sentiment=r.feedback_sentiment,
                reason=r.feedback_reason,
                selected_response=r.message,
                message_history=r.chat._get_messages(),
            )
            for r in self.feedback_responses
        ]

        result: RefinedMetrics = chain.invoke(
            {
                "METRIC_MODEL": json.dumps(
                    Metric.model_json_schema(mode="serialization"), indent=2
                ),
                "METRICS": json.dumps(
                    [metric.model_dump() for metric in self.metrics], indent=2
                ),
                "FEEDBACK": json.dumps(
                    [feedback.dict() for feedback in feedback], indent=2
                ),
                "DDL": self.datasource._get_ddl(),
            }
        )

        existing_metrics = {m.name: m for m in self.metrics}
        refined_metrics = {m.original_name: m.updated_metric for m in result.metrics}

        for name, refined_metric in refined_metrics.items():
            if name in existing_metrics:
                existing_metric = existing_metrics[name]
                print(f"Metric: {name}")

                for field in Metric.model_fields:
                    refined_value = getattr(refined_metric, field)
                    existing_value = getattr(existing_metric, field)
                    if refined_value != existing_value:
                        print(f"  {field}:")
                        print(f"    - Old: {existing_value}")
                        print(f"    + New: {refined_value}")

                        # Handle list fields like dimensions, measures, sample_questions
                        if isinstance(refined_value, list):
                            refined_set = set(str(x) for x in refined_value)
                            existing_set = set(str(x) for x in existing_value)

                            removed = existing_set - refined_set
                            added = refined_set - existing_set

                            if removed:
                                print("    Removed items:")
                                for item in removed:
                                    print(f"      - {item}")

                            if added:
                                print("    Added items:")
                                for item in added:
                                    print(f"      + {item}")

                print()
            else:
                print(f"New Metric: {name}")
                print(f"  + {refined_metric.model_dump_json(indent=2)}")
                print()

        # for metric_container in result.metrics:
        #     self.proposed_changes.append(metric_container.updated_metric)
        self.metrics = [m.updated_metric for m in result.metrics]

        if pr:
            # Create a new branch and open a PR with the refined metrics
            res = self._create_pr_with_refined_metrics(
                [update.updated_metric for update in result.metrics]
            )
            if res:
                return res
            else:
                print("Failed to create a PR.")
                return None
        else:
            print("Not creating a PR.")

    def _create_pr_with_refined_metrics(
        self, refined_metrics: list[Metric]
    ) -> Optional[str]:
        """Creates a new branch with refined metrics and opens a PR. Returns URL of PR if successful, else None."""
        try:
            g = Github(self._config.github_token)
            repo = g.get_repo(self._config.github_repo)

            # Create a new branch
            base_branch = repo.get_branch(self._config.github_base_branch)
            branch_name = f"refined-metrics-{uuid.uuid4().hex[:8]}"
            repo.create_git_ref(f"refs/heads/{branch_name}", base_branch.commit.sha)

            # Update metrics files in the new branch
            for metric in refined_metrics:
                file_path = f"{self._config.relta_semantic_layer_dir_path}/{self.datasource.name}/{metric.name}.json"
                content = metric.model_dump_json(indent=2)

                try:
                    file = repo.get_contents(file_path, ref=branch_name)
                    repo.update_file(
                        file_path,
                        f"Update {metric.name} metric",
                        content,
                        file.sha,
                        branch=branch_name,
                    )
                except GithubException:
                    repo.create_file(
                        file_path,
                        f"Add {metric.name} metric",
                        content,
                        branch=branch_name,
                    )

            # Create a pull request
            pr_title = f"Refined metrics for {self.datasource.name}"
            pr_body = "This PR contains refined metrics based on user feedback."
            pr = repo.create_pull(
                title=pr_title,
                body=pr_body,
                head=branch_name,
                base=self._config.github_base_branch,
            )

            print(
                f"Created a pull request with refined metrics: {pr_title}, id: {pr.id}, url: {pr.html_url}"
            )
            return pr.html_url
        except Exception as e:
            print(f"Error creating PR with refined metrics: {str(e)}")
            return None

    def copy(
        self,
        source: "DataSource",  # noqa: F821
        dump: bool = True,
    ):
        """Copy the semantic layer from another DataSource.

        Args:
            source (DataSource): `DataSource` object to copy the semantic layer from.
            # from_path (Optional[Union[str, Path]], optional): Path to load the semantic layer from, ignoring `source`. If None, uses `source`'s semantic layer. Defaults to None.
            dump (bool, optional): Whether to dump the semantic layer to its path. Defaults to True.
        """
        # if from_path is None:
        self.metrics = [
            metric.model_copy(deep=True) for metric in source.semantic_layer.metrics
        ]
        self.examples = [
            example.model_copy(deep=True) for example in source.semantic_layer.examples
        ]
        # else:
        #     self.load(from_path)

        if dump:
            self.dump()

    def propose(
        self,
        queries: list[str],
        context: Optional[str] = None,
    ):
        """Proposes a new semantic layer for the given datasource and natural language queries.

        Args:
            queries (list[str]): A list of natural language queries that the semantic layer should answer.
            context (Optional[str], optional): Extra information about the datasource. Defaults to None.
        """
        logfire.span("proposing new semantic layer")
        proposed_metrics = self._generate_proposed_metrics(
            [self.datasource._get_ddl()],
            queries,
            self.datasource.name,
            context,
        )
        for m in proposed_metrics.metrics:
            if m.name.lower() == "example":
                m.name = "example_ds"
                logger.info(
                    "Renamed metric 'example' to 'example_ds' to avoid collision with few shot examples."
                )

        self.metrics = proposed_metrics.metrics
        logfire.info(
            "{num_metrics} metrics proposed in semantic layer",
            num_metrics=len(self.metrics),
        )

    def show(self):
        """Prints table of metrics."""
        raise NotImplementedError()

    def _update(self, container: SemanticLayerContainer):
        self.metrics = container.metrics
        self.examples = container.examples
        self.update_reasoning = container.update_reasoning

    @staticmethod
    def _generate_proposed_metrics(
        ddl: list[str],
        questions: list[str],
        source_name: str,
        context: Optional[str] = None,
    ) -> ProposedMetrics:
        """Generates a list of metrics for the given datasource and natural language queries.

        Args:
            ddl (list[str]): The DDL for the datasource.
            questions (list[str]): A list of natural language queries that the semantic layer should answer.
            source_name (str): The name of the datasource.
            context (str, optional): Extra information about the datasource. Defaults to None.
        """
        llm = ChatOpenAI(
            model="gpt-4o-2024-08-06", temperature=0
        ).with_structured_output(ProposedMetrics)
        prompt = PromptTemplate.from_template(SYSTEM_PROMPT_SEMANTIC_LAYER_BUILDER)

        chain = prompt | llm  # | parser
        result: ProposedMetrics = chain.invoke(
            {
                "QUESTIONS": "\n".join(questions),
                "DDL": "\n".join(ddl),
                "CONTEXT": context,
                "DATASOURCE_NAME": source_name,
            }
        )

        return result

copy(source, dump=True)

Copy the semantic layer from another DataSource.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| source | DataSource | DataSource object to copy the semantic layer from. | required |
| dump | bool | Whether to dump the semantic layer to its path. | True |
Source code in src/relta/semantic/semantic_layer.py
def copy(
    self,
    source: "DataSource",  # noqa: F821
    dump: bool = True,
):
    """Copy the semantic layer from another DataSource.

    Args:
        source (DataSource): `DataSource` object to copy the semantic layer from.
        # from_path (Optional[Union[str, Path]], optional): Path to load the semantic layer from, ignoring `source`. If None, uses `source`'s semantic layer. Defaults to None.
        dump (bool, optional): Whether to dump the semantic layer to its path. Defaults to True.
    """
    # if from_path is None:
    self.metrics = [
        metric.model_copy(deep=True) for metric in source.semantic_layer.metrics
    ]
    self.examples = [
        example.model_copy(deep=True) for example in source.semantic_layer.examples
    ]
    # else:
    #     self.load(from_path)

    if dump:
        self.dump()
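
A short sketch of reusing a layer across datasources (the datasource names are illustrative placeholders):

# prod_ds and staging_ds are assumed to be existing DataSource objects.
staging_ds.semantic_layer.copy(prod_ds)               # copies metrics and examples, then dumps to staging's path
staging_ds.semantic_layer.copy(prod_ds, dump=False)   # copy in memory only; call .dump() later to persist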

dump(clear=True, path=None)

Dumps the semantic layer, accepting any updates made to the semantic layer.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| clear | bool | Delete all JSON/YAML files in the path for this layer. See the path attribute for details on the path. | True |
| path | Optional[Union[str, Path]] | Path to dump the semantic layer. If None, uses self.path, which is populated on creation. | None |
Source code in src/relta/semantic/semantic_layer.py
def dump(
    self,
    clear=True,
    path: Optional[Union[str, Path]] = None,
    # apply_proposals: bool = True,
):
    """Dumps the semantic layer, accepting any updates made to the semantic layer.

    Args:
        clear (bool): Delete all JSON/YAML files in the path for this layer. Defaults to True. See `path` attribute for details on the path.
        path (Optional[Union[str, Path]], optional): Path to dump the semantic layer. If None, uses `self.path`, which is populated on creation.
    """

    logfire.info("dumping semantic layer to file")
    p = Path(path) if path is not None else self.path

    if clear:
        for yaml_file in p.glob("*.yaml"):
            yaml_file.unlink()
        for json_file in p.glob("*.json"):
            json_file.unlink()

    for metric in self.metrics:
        with open(p / f"{metric.name}.yaml", "w+") as f:
            yaml.dump(metric.model_dump(), f, sort_keys=False)

    examples = ExampleCollection(examples=self.examples)
    with open(p / "examples.yaml", "w+") as f:
        yaml.dump(examples.model_dump(), f, sort_keys=False)

    # additionally, as dumping is "accepting" the changes, we clean up any updated state
    self.update_reasoning = ""
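
For example (assuming `layer` is a SemanticLayer with in-memory edits to accept; the backup path below is hypothetical):

layer.dump()                                            # clear layer.path and rewrite one YAML file per metric plus examples.yaml
layer.dump(clear=False, path="backups/semantic_layer")  # write to another directory without deleting existing files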

dumps(mode='yaml', **kwargs)

Dumps the metrics and examples to a JSON or YAML string. JSON is typically used for feeding into an agent and YAML for display.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| mode | Literal["json", "yaml"] | Output format. | "yaml" |
| **kwargs | | Keyword arguments to pass to pydantic.BaseModel.model_dump_json or yaml.dump depending on mode. Overrides, by individual key, the default kwargs for printing JSON/YAML set in Configuration. | {} |

Returns:

| Type | Description |
| --- | --- |
| str | JSON or YAML representation of the semantic layer (metrics and examples), depending on mode. |

Source code in src/relta/semantic/semantic_layer.py
def dumps(self, mode: Literal["json", "yaml"] = "yaml", **kwargs) -> str:
    """Dumps the metrics and examples to a JSON or YAML string. JSON is typically used for feeding into an agent and YAML for display.

    Args:
        mode (Literal["json", "yaml"], optional): Output format. Defaults to "yaml".
        **kwargs: Keyword arguments to pass to `pydantic.BaseModel.model_dump_json` or `yaml.dump` depending on mode.
            Will override, by individual key, the default kwargs for printing JSON/YAML set in `Configuration`.

    Returns:
        str: JSON or YAML representation of the semantic layer (metrics and examples), depending on `mode`.
    """
    default_kwargs = (
        self._config.json_dumps_kwargs
        if mode == "json"
        else self._config.yaml_dumps_kwargs
    )

    default_kwargs.update(kwargs)

    ctr = SemanticLayerContainer(**vars(self))

    return (
        ctr.model_dump_json(**default_kwargs)
        if mode == "json"
        else yaml.dump(ctr.model_dump(), stream=None, **default_kwargs)
    )
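
A quick sketch of the two output modes (assuming `layer` is a populated SemanticLayer):

print(layer.dumps())                            # YAML string, typically for display
payload = layer.dumps(mode="json", indent=2)    # JSON string, typically for feeding into an agent; indent overrides the configured default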

load(path=None, json=True, metrics_to_load=None)

Load semantic layer.

Changes to the metrics are not persisted on disk. Use .dump() to persist them.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| path | Optional[Union[str, Path]] | Path to load the semantic layer. If None, uses self.path, which is populated on creation. | None |
| json | bool | Whether to additionally load the semantic layer from deprecated JSON files. If a metric exists in both JSON and YAML files by metric.name, the JSON metric is ignored. | True |
| metrics_to_load | Optional[list[str]] | List of metric names to load. If None, loads all metrics. | None |
Source code in src/relta/semantic/semantic_layer.py
def load(
    self,
    path: Optional[Union[str, Path]] = None,
    json: bool = True,
    metrics_to_load: Optional[list[str]] = None,
):
    """Load semantic layer.

    Changes to the metrics are not persisted on disk. Use `.dump()` to persist them.

    Args:
        path (Optional[Union[str, Path]], optional): Path to load the semantic layer. If None, uses `self.path`, which is populated on creation.
        json (bool, optional): Whether to additionally load the semantic layer from deprecated JSON files.
            If a metric exists in both JSON and YAML files by `metric.name`, the JSON metric is ignored. Defaults to True.
        metrics_to_load (Optional[list[str]], optional): List of metric names to load. If None, loads all metrics. Defaults to None.
    """
    logfire.info("loading semantic layer from {path}", path=str(path))
    p = Path(path) if path is not None else self.path
    metrics = []
    yaml_metric_names = set()
    examples = []

    for fpath in p.glob("*.yaml"):
        with open(fpath, "r") as f:
            data = yaml.safe_load(f)
            if fpath.name == "examples.yaml":
                example_coll = ExampleCollection.model_validate(data)
                examples.extend(example_coll.examples)
            else:
                metric = Metric.model_validate(data)
                if metrics_to_load is None or metric.name in metrics_to_load:
                    metrics.append(metric)
                    yaml_metric_names.add(metric.name)

    if json:
        for fpath in p.glob("*.json"):
            with open(fpath, "r") as f:
                if fpath.name == "examples.json":
                    example_coll = ExampleCollection.model_validate_json(f.read())
                    examples.extend(example_coll.examples)
                else:
                    metric = Metric.model_validate_json(f.read())
                    if (
                        metrics_to_load is None or metric.name in metrics_to_load
                    ) and metric.name not in yaml_metric_names:
                        metrics.append(metric)
    self.metrics = metrics
    self.examples = examples
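
For example (assuming the layer was previously dumped to its default path; the metric name below is hypothetical):

layer.load()                                                    # load all YAML metrics plus any deprecated JSON ones
layer.load(json=False, metrics_to_load=["revenue_by_month"])    # YAML only, restricted to a single metric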

propose(queries, context=None)

Proposes a new semantic layer for the given datasource and natural language queries.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| queries | list[str] | A list of natural language queries that the semantic layer should answer. | required |
| context | Optional[str] | Extra information about the datasource. | None |
Source code in src/relta/semantic/semantic_layer.py
def propose(
    self,
    queries: list[str],
    context: Optional[str] = None,
):
    """Proposes a new semantic layer for the given datasource and natural language queries.

    Args:
        queries (list[str]): A list of natural language queries that the semantic layer should answer.
        context (Optional[str], optional): Extra information about the datasource. Defaults to None.
    """
    logfire.span("proposing new semantic layer")
    proposed_metrics = self._generate_proposed_metrics(
        [self.datasource._get_ddl()],
        queries,
        self.datasource.name,
        context,
    )
    for m in proposed_metrics.metrics:
        if m.name.lower() == "example":
            m.name = "example_ds"
            logger.info(
                "Renamed metric 'example' to 'example_ds' to avoid collision with few shot examples."
            )

    self.metrics = proposed_metrics.metrics
    logfire.info(
        "{num_metrics} metrics proposed in semantic layer",
        num_metrics=len(self.metrics),
    )
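
A minimal sketch (the queries and context are illustrative placeholders):

layer.propose(
    queries=[
        "What is total revenue by month?",
        "Which customers churned last quarter?",
    ],
    context="Orders table contains one row per line item.",
)
layer.dump()   # persist the proposed metrics once reviewed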

refine(pr=False)

Refines the semantic layer based on the feedback and creates a PR with the changes. By default, sets the refined metrics, but does not persist on disk -- see dump() to persist.

If pr=True, attempts to create a PR on the configured GitHub repo (see Configuration) after updating the metrics in the in-memory semantic layer. If successful, returns the URL of the PR; otherwise returns None.

Source code in src/relta/semantic/semantic_layer.py
def refine(self, pr=False):
    """Refines the semantic layer based on the feedback and creates a PR with the changes. By default, sets the refined metrics, but does not persist on disk -- see `dump()` to persist.

    If `pr=True`, attempts to create a PR on the configured GitHub repo (see `Configuration`) after updating the metrics in the in-memory semantic layer.
    If it is successful, returns the URL of the PR. Else, returns `None`.
    """
    logfire.info("refine semantic layer")
    llm = ChatOpenAI(model="gpt-4o", temperature=0).with_structured_output(
        RefinedMetrics
    )
    prompt = PromptTemplate.from_template(REFINE_SEMANTIC_LAYER_PROMPT)

    chain = prompt | llm

    feedback = [
        Feedback(
            sentiment=r.feedback_sentiment,
            reason=r.feedback_reason,
            selected_response=r.message,
            message_history=r.chat._get_messages(),
        )
        for r in self.feedback_responses
    ]

    result: RefinedMetrics = chain.invoke(
        {
            "METRIC_MODEL": json.dumps(
                Metric.model_json_schema(mode="serialization"), indent=2
            ),
            "METRICS": json.dumps(
                [metric.model_dump() for metric in self.metrics], indent=2
            ),
            "FEEDBACK": json.dumps(
                [feedback.dict() for feedback in feedback], indent=2
            ),
            "DDL": self.datasource._get_ddl(),
        }
    )

    existing_metrics = {m.name: m for m in self.metrics}
    refined_metrics = {m.original_name: m.updated_metric for m in result.metrics}

    for name, refined_metric in refined_metrics.items():
        if name in existing_metrics:
            existing_metric = existing_metrics[name]
            print(f"Metric: {name}")

            for field in Metric.model_fields:
                refined_value = getattr(refined_metric, field)
                existing_value = getattr(existing_metric, field)
                if refined_value != existing_value:
                    print(f"  {field}:")
                    print(f"    - Old: {existing_value}")
                    print(f"    + New: {refined_value}")

                    # Handle list fields like dimensions, measures, sample_questions
                    if isinstance(refined_value, list):
                        refined_set = set(str(x) for x in refined_value)
                        existing_set = set(str(x) for x in existing_value)

                        removed = existing_set - refined_set
                        added = refined_set - existing_set

                        if removed:
                            print("    Removed items:")
                            for item in removed:
                                print(f"      - {item}")

                        if added:
                            print("    Added items:")
                            for item in added:
                                print(f"      + {item}")

            print()
        else:
            print(f"New Metric: {name}")
            print(f"  + {refined_metric.model_dump_json(indent=2)}")
            print()

    # for metric_container in result.metrics:
    #     self.proposed_changes.append(metric_container.updated_metric)
    self.metrics = [m.updated_metric for m in result.metrics]

    if pr:
        # Create a new branch and open a PR with the refined metrics
        res = self._create_pr_with_refined_metrics(
            [update.updated_metric for update in result.metrics]
        )
        if res:
            return res
        else:
            print("Failed to create a PR.")
            return None
    else:
        print("Not creating a PR.")

show()

Prints table of metrics.

Source code in src/relta/semantic/semantic_layer.py
def show(self):
    """Prints table of metrics."""
    raise NotImplementedError()

Dimension

Bases: BaseModel

A dimension for a metric.

This is similar to a dimension in LookML -- it can be used to group or filter data in a metric. For example, a dimension could be a column in a table or a calculation based on columns in a table.

Source code in src/relta/semantic/base.py
class Dimension(BaseModel):
    """A dimension for a metric.

    This is similar to a dimension in LookML -- it can be used to group or filter data in a metric.
    For example, a dimension could be a column in a table or a calculation based on columns in a table.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    name: str = Field(
        description="A short name of the dimension. Must be unique within the metric. It can be the same as the column name."
    )

    description: str = Field(
        description="A longer description of the dimension. Keep it within 3 sentences."
    )

    categories: Optional[list] = Field(
        description="The categories of the dimension. This is used to create a list of values for the dimension. This should not be filled out when setting up the semantic layer."
    )

    skip_categorical_load: bool = Field(
        default=False,
        description="Controls whether to load the categorical values for this dimension. Defaults to False. To be used if the data is large text."
    )

    dtype: Optional[str] = Field(description="The data type of the dimension.")
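
For illustration, a dimension might be authored like this (all field values are hypothetical):

Dimension(
    name="order_status",
    description="Status of the order, e.g. pending, shipped, or delivered.",
    categories=None,              # left empty when authoring; populated with observed values later
    skip_categorical_load=False,
    dtype="VARCHAR",
)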

ExampleCollection

Bases: BaseModel

Used for persistence (load/dump) of examples.

Source code in src/relta/semantic/base.py
class ExampleCollection(BaseModel):
    """Used for persistence (load/dump) of examples."""

    examples: list[Example] = Field(description="A collection of examples.")

Measure

Bases: BaseModel

A measure for a metric.

This is similar to a measure in LookML -- it is an aggregate operation on some dimensions. For example, a measure could be the sum of a column or the average of columnA * columnB.

Source code in src/relta/semantic/base.py
class Measure(BaseModel):
    """A measure for a metric.

    This is similar to a measure in LookML -- it is an aggregate operation on some dimensions.
    For example, a measure could be the sum of a column or the average of columnA * columnB.
    """

    name: str = Field(
        description="A short name of the measure. Must be unique within the metric."
    )
    description: str = Field(
        description="A longer description of the measure. Keep it within 3 sentences."
    )
    expr: str = Field(
        description="The SQL expression for this aggregate operation. This must be a valid PostgreSQL expression. Any columns used should also be defined as dimensions."
    )