Skip to content

Commit

Permalink
zeta: Report Fireworks request data to Snowflake (#22973)
Browse files Browse the repository at this point in the history
Release Notes:

- N/A

---------

Co-authored-by: Antonio Scandurra <[email protected]>
Co-authored-by: Conrad <[email protected]>
  • Loading branch information
3 people authored Jan 10, 2025
1 parent 3d80b21 commit 1fcc9b3
Show file tree
Hide file tree
Showing 7 changed files with 236 additions and 3 deletions.
12 changes: 12 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ members = [
"crates/feedback",
"crates/file_finder",
"crates/file_icons",
"crates/fireworks",
"crates/fs",
"crates/fsevent",
"crates/fuzzy",
Expand Down Expand Up @@ -222,6 +223,7 @@ feature_flags = { path = "crates/feature_flags" }
feedback = { path = "crates/feedback" }
file_finder = { path = "crates/file_finder" }
file_icons = { path = "crates/file_icons" }
fireworks = { path = "crates/fireworks" }
fs = { path = "crates/fs" }
fsevent = { path = "crates/fsevent" }
fuzzy = { path = "crates/fuzzy" }
Expand Down
1 change: 1 addition & 0 deletions crates/collab/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ collections.workspace = true
dashmap.workspace = true
derive_more.workspace = true
envy = "0.4.2"
fireworks.workspace = true
futures.workspace = true
google_ai.workspace = true
hex.workspace = true
Expand Down
31 changes: 28 additions & 3 deletions crates/collab/src/llm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -470,23 +470,48 @@ async fn predict_edits(
.replace("<outline>", &outline_prefix)
.replace("<events>", &params.input_events)
.replace("<excerpt>", &params.input_excerpt);
let mut response = open_ai::complete_text(
let mut response = fireworks::complete(
&state.http_client,
api_url,
api_key,
open_ai::CompletionRequest {
fireworks::CompletionRequest {
model: model.to_string(),
prompt: prompt.clone(),
max_tokens: 2048,
temperature: 0.,
prediction: Some(open_ai::Prediction::Content {
prediction: Some(fireworks::Prediction::Content {
content: params.input_excerpt,
}),
rewrite_speculation: Some(true),
},
)
.await?;

state.executor.spawn_detached({
let kinesis_client = state.kinesis_client.clone();
let kinesis_stream = state.config.kinesis_stream.clone();
let headers = response.headers.clone();
let model = model.clone();

async move {
SnowflakeRow::new(
"Fireworks Completion Requested",
claims.metrics_id,
claims.is_staff,
claims.system_id.clone(),
json!({
"model": model.to_string(),
"headers": headers,
}),
)
.write(&kinesis_client, &kinesis_stream)
.await
.log_err();
}
});

let choice = response
.completion
.choices
.pop()
.context("no output from completion response")?;
Expand Down
19 changes: 19 additions & 0 deletions crates/fireworks/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
[package]
name = "fireworks"
version = "0.1.0"
edition = "2021"
publish = false
license = "GPL-3.0-or-later"

[lints]
workspace = true

[lib]
path = "src/fireworks.rs"

[dependencies]
anyhow.workspace = true
futures.workspace = true
http_client.workspace = true
serde.workspace = true
serde_json.workspace = true
1 change: 1 addition & 0 deletions crates/fireworks/LICENSE-GPL
173 changes: 173 additions & 0 deletions crates/fireworks/src/fireworks.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
use anyhow::{anyhow, Result};
use futures::AsyncReadExt;
use http_client::{http::HeaderMap, AsyncBody, HttpClient, Method, Request as HttpRequest};
use serde::{Deserialize, Serialize};

pub const FIREWORKS_API_URL: &str = "https://api.openai.com/v1";

#[derive(Debug, Serialize, Deserialize)]
pub struct CompletionRequest {
pub model: String,
pub prompt: String,
pub max_tokens: u32,
pub temperature: f32,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub prediction: Option<Prediction>,
#[serde(default, skip_serializing_if = "Option::is_none")]
pub rewrite_speculation: Option<bool>,
}

#[derive(Clone, Deserialize, Serialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum Prediction {
Content { content: String },
}

#[derive(Debug)]
pub struct Response {
pub completion: CompletionResponse,
pub headers: Headers,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct CompletionResponse {
pub id: String,
pub object: String,
pub created: u64,
pub model: String,
pub choices: Vec<CompletionChoice>,
pub usage: Usage,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct CompletionChoice {
pub text: String,
}

#[derive(Serialize, Deserialize, Debug)]
pub struct Usage {
pub prompt_tokens: u32,
pub completion_tokens: u32,
pub total_tokens: u32,
}

#[derive(Debug, Clone, Default, Serialize)]
pub struct Headers {
pub server_processing_time: Option<f64>,
pub request_id: Option<String>,
pub prompt_tokens: Option<u32>,
pub speculation_generated_tokens: Option<u32>,
pub cached_prompt_tokens: Option<u32>,
pub backend_host: Option<String>,
pub num_concurrent_requests: Option<u32>,
pub deployment: Option<String>,
pub tokenizer_queue_duration: Option<f64>,
pub tokenizer_duration: Option<f64>,
pub prefill_queue_duration: Option<f64>,
pub prefill_duration: Option<f64>,
pub generation_queue_duration: Option<f64>,
}

impl Headers {
pub fn parse(headers: &HeaderMap) -> Self {
Headers {
request_id: headers
.get("x-request-id")
.and_then(|v| v.to_str().ok())
.map(String::from),
server_processing_time: headers
.get("fireworks-server-processing-time")
.and_then(|v| v.to_str().ok()?.parse().ok()),
prompt_tokens: headers
.get("fireworks-prompt-tokens")
.and_then(|v| v.to_str().ok()?.parse().ok()),
speculation_generated_tokens: headers
.get("fireworks-speculation-generated-tokens")
.and_then(|v| v.to_str().ok()?.parse().ok()),
cached_prompt_tokens: headers
.get("fireworks-cached-prompt-tokens")
.and_then(|v| v.to_str().ok()?.parse().ok()),
backend_host: headers
.get("fireworks-backend-host")
.and_then(|v| v.to_str().ok())
.map(String::from),
num_concurrent_requests: headers
.get("fireworks-num-concurrent-requests")
.and_then(|v| v.to_str().ok()?.parse().ok()),
deployment: headers
.get("fireworks-deployment")
.and_then(|v| v.to_str().ok())
.map(String::from),
tokenizer_queue_duration: headers
.get("fireworks-tokenizer-queue-duration")
.and_then(|v| v.to_str().ok()?.parse().ok()),
tokenizer_duration: headers
.get("fireworks-tokenizer-duration")
.and_then(|v| v.to_str().ok()?.parse().ok()),
prefill_queue_duration: headers
.get("fireworks-prefill-queue-duration")
.and_then(|v| v.to_str().ok()?.parse().ok()),
prefill_duration: headers
.get("fireworks-prefill-duration")
.and_then(|v| v.to_str().ok()?.parse().ok()),
generation_queue_duration: headers
.get("fireworks-generation-queue-duration")
.and_then(|v| v.to_str().ok()?.parse().ok()),
}
}
}

pub async fn complete(
client: &dyn HttpClient,
api_url: &str,
api_key: &str,
request: CompletionRequest,
) -> Result<Response> {
let uri = format!("{api_url}/completions");
let request_builder = HttpRequest::builder()
.method(Method::POST)
.uri(uri)
.header("Content-Type", "application/json")
.header("Authorization", format!("Bearer {}", api_key));

let request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
let mut response = client.send(request).await?;

if response.status().is_success() {
let headers = Headers::parse(response.headers());

let mut body = String::new();
response.body_mut().read_to_string(&mut body).await?;

Ok(Response {
completion: serde_json::from_str(&body)?,
headers,
})
} else {
let mut body = String::new();
response.body_mut().read_to_string(&mut body).await?;

#[derive(Deserialize)]
struct FireworksResponse {
error: FireworksError,
}

#[derive(Deserialize)]
struct FireworksError {
message: String,
}

match serde_json::from_str::<FireworksResponse>(&body) {
Ok(response) if !response.error.message.is_empty() => Err(anyhow!(
"Failed to connect to Fireworks API: {}",
response.error.message,
)),

_ => Err(anyhow!(
"Failed to connect to Fireworks API: {} {}",
response.status(),
body,
)),
}
}
}

0 comments on commit 1fcc9b3

Please sign in to comment.